Skip to content

Commit 2e850c3

Browse files
committed
Fix according to comments.
1 parent 6edd712 commit 2e850c3

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

mlir/test/Dialect/Vector/flatten-memref-and-emulate-narrow-types.mlir

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,56 +9,56 @@
99
//
1010
// TODO: Support `vector.transfer_write` operation.
1111

12-
func.func @vector_load_2d_i4(%arg0: index, %arg1: index) -> vector<8xi4> {
12+
func.func @vector_load_2d_i4(%arg0: index) -> vector<8xi4> {
1313
%0 = memref.alloc() : memref<4x8xi4>
14-
%1 = vector.load %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
14+
%1 = vector.load %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
1515
return %1 : vector<8xi4>
1616
}
17-
// CHECK: func @vector_load_2d_i4
18-
// CHECK: vector.load {{.*}} memref<16xi8>
17+
// CHECK-LABEL: func @vector_load_2d_i4
18+
// CHECK: vector.load {{.*}} memref<16xi8>
1919

2020
// -----
2121

22-
func.func @vector_maskedload_2d_i4(%arg0: index, %arg1: index, %passthru: vector<8xi4>) -> vector<8xi4> {
22+
func.func @vector_maskedload_2d_i4(%arg0: index, %passthru: vector<8xi4>) -> vector<8xi4> {
2323
%0 = memref.alloc() : memref<4x8xi4>
2424
%mask = vector.constant_mask [6] : vector<8xi1>
25-
%1 = vector.maskedload %0[%arg0, %arg1], %mask, %passthru :
25+
%1 = vector.maskedload %0[%arg0, %arg0], %mask, %passthru :
2626
memref<4x8xi4>, vector<8xi1>, vector<8xi4> into vector<8xi4>
2727
return %1 : vector<8xi4>
2828
}
29-
// CHECK: func @vector_maskedload_2d_i4(
30-
// CHECK: vector.maskedload {{.*}} memref<16xi8>
29+
// CHECK-LABEL: func @vector_maskedload_2d_i4(
30+
// CHECK: vector.maskedload {{.*}} memref<16xi8>
3131

3232
// -----
3333

34-
func.func @vector_maskedstore_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
34+
func.func @vector_maskedstore_2d_i4(%arg0: index, %value: vector<8xi4>) {
3535
%0 = memref.alloc() : memref<4x8xi4>
3636
%mask = vector.constant_mask [5] : vector<8xi1>
37-
vector.maskedstore %0[%arg0, %arg1], %mask, %value :
37+
vector.maskedstore %0[%arg0, %arg0], %mask, %value :
3838
memref<4x8xi4>, vector<8xi1>, vector<8xi4>
3939
return
4040
}
41-
// CHECK: func @vector_maskedstore_2d_i4(
42-
// CHECK: vector.maskedstore {{.*}} memref<16xi8>
41+
// CHECK-LABEL: func @vector_maskedstore_2d_i4(
42+
// CHECK: vector.maskedstore {{.*}} memref<16xi8>
4343

4444
// -----
4545

46-
func.func @vector_store_2d_i4(%arg0: index, %arg1: index, %value: vector<8xi4>) {
46+
func.func @vector_store_2d_i4(%arg0: index, %value: vector<8xi4>) {
4747
%0 = memref.alloc() : memref<4x8xi4>
48-
vector.store %value, %0[%arg0, %arg1] : memref<4x8xi4>, vector<8xi4>
48+
vector.store %value, %0[%arg0, %arg0] : memref<4x8xi4>, vector<8xi4>
4949
return
5050
}
51-
// CHECK: func @vector_store_2d_i4(
52-
// CHECK: vector.store {{.*}} memref<16xi8>
51+
// CHECK-LABEL: func @vector_store_2d_i4(
52+
// CHECK: vector.store {{.*}} memref<16xi8>
5353

5454
// -----
5555

56-
func.func @vector_transfer_read_2d_i4(%arg0: index, %arg1: index, %padding: i4) -> vector<8xi4> {
56+
func.func @vector_transfer_read_2d_i4(%arg0: index, %padding: i4) -> vector<8xi4> {
5757
%0 = memref.alloc() : memref<4x8xi4>
58-
%1 = vector.transfer_read %0[%arg0, %arg1], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
58+
%1 = vector.transfer_read %0[%arg0, %arg0], %padding {in_bounds = [true]} : memref<4x8xi4>, vector<8xi4>
5959
return %1 : vector<8xi4>
6060
}
61-
// CHECK: func @vector_transfer_read_2d_i4(
62-
// CHECK-SAME: %{{.*}}: index, %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
63-
// CHECK: %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
64-
// CHECK: vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>
61+
// CHECK-LABEL: func @vector_transfer_read_2d_i4(
62+
// CHECK-SAME: %{{.*}}: index, %[[PADDING_I4:.*]]: i4)
63+
// CHECK: %[[PADDING_I8:.*]] = arith.extui %[[PADDING_I4]] : i4 to i8
64+
// CHECK: vector.transfer_read {{.*}}, %[[PADDING_I8]] : memref<16xi8>, vector<4xi8>

mlir/test/lib/Dialect/MemRef/TestEmulateNarrowType.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,10 @@ struct TestMemRefFlattenAndVectorNarrowTypeEmulationPass
177177

178178
RewritePatternSet patterns(ctx);
179179

180+
// This is necessary for the purpose of emulating `memref.alloc` and
181+
// function boundaries.
180182
memref::populateMemRefNarrowTypeEmulationPatterns(typeConverter, patterns);
183+
181184
vector::populateMemRefFlattenAndVectorNarrowTypeEmulationPatterns(
182185
typeConverter, patterns);
183186

0 commit comments

Comments
 (0)