Skip to content

Commit 9d91abe

Browse files
committed
Clean up
1 parent 2d8174c commit 9d91abe

File tree

3 files changed

+83
-113
lines changed

3 files changed

+83
-113
lines changed

mlir/test/Dialect/Vector/vector-load-store-unroll.mlir

Lines changed: 0 additions & 73 deletions
This file was deleted.

mlir/test/Dialect/Vector/vector-unroll-options.mlir

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,76 @@ func.func @vector_broadcast_with_tailing_unit_dim(%v: vector<4x1xf32>) -> vector
378378
// CHECK: [[b3:%.+]] = vector.broadcast [[s3]] : vector<2x1xf32> to vector<2x2xf32>
379379
// CHECK: [[r3:%.+]] = vector.insert_strided_slice [[b3]], [[r2]] {offsets = [2, 2], strides = [1, 1]} : vector<2x2xf32> into vector<4x4xf32>
380380
// CHECK: return [[r3]] : vector<4x4xf32>
381+
382+
383+
// CHECK-LABEL: func.func @unroll_2D_vector_load(
384+
// CHECK-SAME: %[[ARG:.*]]: memref<4x4xf16>) -> vector<4x4xf16> {
385+
func.func @unroll_2D_vector_load(%arg0: memref<4x4xf16>) -> vector<4x4xf16> {
386+
// CHECK: %[[C3:.*]] = arith.constant 3 : index
387+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
388+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
389+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
390+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf16>
391+
// CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
392+
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
393+
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
394+
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
395+
// CHECK: %[[V4:.*]] = vector.load %[[ARG]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
396+
// CHECK: %[[V5:.*]] = vector.insert_strided_slice %[[V4]], %[[V3]] {offsets = [2, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
397+
// CHECK: %[[V6:.*]] = vector.load %[[ARG]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
398+
// CHECK: %[[V7:.*]] = vector.insert_strided_slice %[[V6]], %[[V5]] {offsets = [3, 0], strides = [1]} : vector<4xf16> into vector<4x4xf16>
399+
// CHECK: return %[[V7]] : vector<4x4xf16>
400+
%c0 = arith.constant 0 : index
401+
%0 = vector.load %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
402+
return %0 : vector<4x4xf16>
403+
}
404+
405+
// CHECK-LABEL: func.func @unroll_2D_vector_store(
406+
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4xf16>, %[[ARG1:.*]]: vector<4x4xf16>) {
407+
func.func @unroll_2D_vector_store(%arg0: memref<4x4xf16>, %arg1: vector<4x4xf16>) {
408+
// CHECK: %[[C3:.*]] = arith.constant 3 : index
409+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
410+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
411+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
412+
// CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<4xf16> from vector<4x4xf16>
413+
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C0]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
414+
// CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<4xf16> from vector<4x4xf16>
415+
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
416+
// CHECK: %[[V2:.*]] = vector.extract %[[ARG1]][2] : vector<4xf16> from vector<4x4xf16>
417+
// CHECK: vector.store %[[V2]], %[[ARG0]][%[[C2]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
418+
// CHECK: %[[V3:.*]] = vector.extract %[[ARG1]][3] : vector<4xf16> from vector<4x4xf16>
419+
// CHECK: vector.store %[[V3]], %[[ARG0]][%[[C3]], %[[C0]]] : memref<4x4xf16>, vector<4xf16>
420+
%c0 = arith.constant 0 : index
421+
vector.store %arg1, %arg0[%c0, %c0] : memref<4x4xf16>, vector<4x4xf16>
422+
return
423+
}
424+
425+
// CHECK-LABEL: func.func @unroll_vector_load(
426+
// CHECK-SAME: %[[ARG:.*]]: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
427+
func.func @unroll_vector_load(%arg0: memref<4x4x4x4xf16>) -> vector<2x2xf16> {
428+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
429+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
430+
// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2xf16>
431+
// CHECK: %[[V0:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
432+
// CHECK: %[[V1:.*]] = vector.insert_strided_slice %[[V0]], %[[CST]] {offsets = [0, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
433+
// CHECK: %[[V2:.*]] = vector.load %[[ARG]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
434+
// CHECK: %[[V3:.*]] = vector.insert_strided_slice %[[V2]], %[[V1]] {offsets = [1, 0], strides = [1]} : vector<2xf16> into vector<2x2xf16>
435+
// CHECK: return %[[V3]] : vector<2x2xf16>
436+
%c1 = arith.constant 1 : index
437+
%0 = vector.load %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
438+
return %0 : vector<2x2xf16>
439+
}
440+
441+
// CHECK-LABEL: func.func @unroll_vector_store(
442+
// CHECK-SAME: %[[ARG0:.*]]: memref<4x4x4x4xf16>, %[[ARG1:.*]]: vector<2x2xf16>) {
443+
func.func @unroll_vector_store(%arg0: memref<4x4x4x4xf16>, %arg1: vector<2x2xf16>) {
444+
// CHECK: %[[C2:.*]] = arith.constant 2 : index
445+
// CHECK: %[[C1:.*]] = arith.constant 1 : index
446+
// CHECK: %[[V0:.*]] = vector.extract %[[ARG1]][0] : vector<2xf16> from vector<2x2xf16>
447+
// CHECK: vector.store %[[V0]], %[[ARG0]][%[[C1]], %[[C1]], %[[C1]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
448+
// CHECK: %[[V1:.*]] = vector.extract %[[ARG1]][1] : vector<2xf16> from vector<2x2xf16>
449+
// CHECK: vector.store %[[V1]], %[[ARG0]][%[[C1]], %[[C1]], %[[C2]], %[[C1]]] : memref<4x4x4x4xf16>, vector<2xf16>
450+
%c1 = arith.constant 1 : index
451+
vector.store %arg1, %arg0[%c1, %c1, %c1, %c1] : memref<4x4x4x4xf16>, vector<2x2xf16>
452+
return
453+
}

mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp

Lines changed: 10 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,16 @@ struct TestVectorUnrollingPatterns
178178
return success(isa<vector::TransposeOp>(op));
179179
}));
180180

181+
populateVectorUnrollPatterns(
182+
patterns, UnrollVectorOptions()
183+
.setNativeShape(ArrayRef<int64_t>{2, 2})
184+
.setFilterConstraint([](Operation *op) {
185+
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
186+
return success(loadOp.getType().getRank() > 1);
187+
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
188+
return success(storeOp.getVectorType().getRank() > 1);
189+
return failure();
190+
}));
181191
if (unrollBasedOnType) {
182192
UnrollVectorOptions::NativeShapeFnType nativeShapeFn =
183193
[](Operation *op) -> std::optional<SmallVector<int64_t>> {
@@ -292,44 +302,6 @@ struct TestVectorTransferUnrollingPatterns
292302
llvm::cl::init(false)};
293303
};
294304

295-
struct TestVectorLoadStoreUnrollPatterns
296-
: public PassWrapper<TestVectorLoadStoreUnrollPatterns,
297-
OperationPass<func::FuncOp>> {
298-
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
299-
TestVectorLoadStoreUnrollPatterns)
300-
301-
StringRef getArgument() const final {
302-
return "test-vector-load-store-unroll";
303-
}
304-
StringRef getDescription() const final {
305-
return "Test unrolling patterns for vector.load and vector.store ops";
306-
}
307-
308-
void getDependentDialects(DialectRegistry &registry) const override {
309-
registry.insert<vector::VectorDialect, arith::ArithDialect>();
310-
}
311-
312-
void runOnOperation() override {
313-
MLIRContext *ctx = &getContext();
314-
RewritePatternSet patterns(ctx);
315-
316-
// Unroll all vector.load and vector.store ops with rank > 1 to 1D vectors
317-
vector::UnrollVectorOptions options;
318-
options.setFilterConstraint([](Operation *op) {
319-
if (auto loadOp = dyn_cast<vector::LoadOp>(op))
320-
return success(loadOp.getType().getRank() > 1);
321-
if (auto storeOp = dyn_cast<vector::StoreOp>(op))
322-
return success(storeOp.getVectorType().getRank() > 1);
323-
return failure();
324-
});
325-
326-
vector::populateVectorUnrollPatterns(patterns, options);
327-
328-
// Apply the patterns
329-
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
330-
}
331-
};
332-
333305
struct TestScalarVectorTransferLoweringPatterns
334306
: public PassWrapper<TestScalarVectorTransferLoweringPatterns,
335307
OperationPass<func::FuncOp>> {
@@ -1070,8 +1042,6 @@ void registerTestVectorLowerings() {
10701042

10711043
PassRegistration<TestVectorTransferUnrollingPatterns>();
10721044

1073-
PassRegistration<TestVectorLoadStoreUnrollPatterns>();
1074-
10751045
PassRegistration<TestScalarVectorTransferLoweringPatterns>();
10761046

10771047
PassRegistration<TestVectorTransferOpt>();

0 commit comments

Comments
 (0)