Skip to content

Commit 84c4228

Browse files
committed
fixup! [MLIR][Vector] Update Transfer{Read|Write}DropUnitDimsPattern patterns
Bail out for with 0D cases
1 parent 0de9b8c commit 84c4228

File tree

3 files changed

+52
-2
lines changed

3 files changed

+52
-2
lines changed

mlir/lib/Dialect/Vector/Transforms/VectorTransferOpTransforms.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,10 @@ class TransferReadDropUnitDimsPattern
378378
int reducedRank = getReducedRank(sourceType.getShape());
379379
if (reducedRank == sourceType.getRank())
380380
return failure();
381+
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
382+
// out.
383+
if (reducedRank == 0 && maskingOp)
384+
return failure();
381385
// Check if the reduced vector shape matches the reduced source shape.
382386
// Otherwise, this case is not supported yet.
383387
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -415,7 +419,7 @@ class TransferReadDropUnitDimsPattern
415419

416420
if (maskingOp) {
417421
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
418-
loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
422+
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
419423
maskingOp.getMask());
420424
newTransferReadOp = mlir::vector::maskOperation(
421425
rewriter, newTransferReadOp, shapeCastMask);
@@ -456,6 +460,10 @@ class TransferWriteDropUnitDimsPattern
456460
int reducedRank = getReducedRank(sourceType.getShape());
457461
if (reducedRank == sourceType.getRank())
458462
return failure();
463+
// TODO: Extend vector.mask to support 0-d vectors. In the meantime, bail
464+
// out.
465+
if (reducedRank == 0 && maskingOp)
466+
return failure();
459467
// Check if the reduced vector shape matches the reduced destination shape.
460468
// Otherwise, this case is not supported yet.
461469
VectorType reducedVectorType = trimNonScalableUnitDims(vectorType);
@@ -494,7 +502,7 @@ class TransferWriteDropUnitDimsPattern
494502

495503
if (maskingOp) {
496504
auto shapeCastMask = rewriter.createOrFold<vector::ShapeCastOp>(
497-
loc, reducedVectorType.cloneWith({}, rewriter.getI1Type()),
505+
loc, reducedVectorType.cloneWith(std::nullopt, rewriter.getI1Type()),
498506
maskingOp.getMask());
499507
newXferWrite =
500508
mlir::vector::maskOperation(rewriter, newXferWrite, shapeCastMask);

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1717,6 +1717,15 @@ func.func @vector_mask_shape_mismatch(%a: vector<8xi32>, %m0: vector<16xi1>) ->
17171717

17181718
// -----
17191719

1720+
func.func @vector_mask_passthru_type_mismatch(%t0: tensor<f32>, %m0: vector<i1>) -> vector<f32> {
1721+
%ft0 = arith.constant 0.0 : f32
1722+
// expected-error@+1 {{'vector.mask' op operand #0 must be vector of 1-bit signless integer values, but got 'vector<i1>'}}
1723+
%0 = vector.mask %m0 { vector.transfer_read %t0[], %ft0 : tensor<f32>, vector<f32> } : vector<i1> -> vector<f32>
1724+
return %0 : vector<f32>
1725+
}
1726+
1727+
// -----
1728+
17201729
// expected-note@+1 {{prior use here}}
17211730
func.func @vector_mask_passthru_type_mismatch(%t0: tensor<?xf32>, %idx: index, %m0: vector<16xi1>, %pt0: vector<16xi32>) -> vector<16xf32> {
17221731
%ft0 = arith.constant 0.0 : f32

mlir/test/Dialect/Vector/vector-transfer-drop-unit-dims-patterns.mlir

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,22 @@ func.func @transfer_read_and_vector_rank_reducing_to_0d(
114114
// CHECK: %[[READ:.+]] = vector.transfer_read %[[SUBVIEW]]{{.*}} : memref<f32>, vector<f32>
115115
// CHECK: vector.shape_cast %[[READ]] : vector<f32> to vector<1x1x1xf32>
116116

117+
func.func @transfer_read_and_vector_rank_reducing_to_0d_masked(
118+
%arg : memref<1x1x1x1x1xf32>,
119+
%mask: vector<1x1x1xi1>) -> vector<1x1x1xf32> {
120+
121+
%c0 = arith.constant 0 : index
122+
%cst = arith.constant 0.0 : f32
123+
%v = vector.mask %mask {
124+
vector.transfer_read %arg[%c0, %c0, %c0, %c0, %c0], %cst
125+
: memref<1x1x1x1x1xf32>, vector<1x1x1xf32>
126+
} : vector<1x1x1xi1> -> vector<1x1x1xf32>
127+
return %v : vector<1x1x1xf32>
128+
}
129+
// CHECK-LABEL: func @transfer_read_and_vector_rank_reducing_to_0d_masked
130+
// CHECK-NOT: vector.shape_cast
131+
// CHECK-NOT: memref.subview
132+
117133
func.func @transfer_write_and_vector_rank_reducing_to_0d(
118134
%arg : memref<1x1x1x1x1xf32>,
119135
%vec : vector<1x1x1xf32>) {
@@ -128,6 +144,23 @@ func.func @transfer_write_and_vector_rank_reducing_to_0d(
128144
// CHECK: %[[SHCAST:.+]] = vector.shape_cast %[[VECTOR]] : vector<1x1x1xf32> to vector<f32>
129145
// CHECK: vector.transfer_write %[[SHCAST]], %[[SUBVIEW]]{{.*}} : vector<f32>, memref<f32>
130146

147+
func.func @transfer_write_and_vector_rank_reducing_to_0d_masked(
148+
%arg : memref<1x1x1x1x1xf32>,
149+
%vec : vector<1x1x1xf32>,
150+
%mask: vector<1x1x1xi1>) {
151+
152+
%c0 = arith.constant 0 : index
153+
%cst = arith.constant 0.0 : f32
154+
vector.mask %mask {
155+
vector.transfer_write %vec, %arg[%c0, %c0, %c0, %c0, %c0] :
156+
vector<1x1x1xf32>, memref<1x1x1x1x1xf32>
157+
} : vector<1x1x1xi1>
158+
return
159+
}
160+
// CHECK-LABEL: func @transfer_write_and_vector_rank_reducing_to_0d_masked
161+
// CHECK-NOT: vector.shape_cast
162+
// CHECK-NOT: memref.subview
163+
131164
func.func @transfer_read_dynamic_rank_reducing(
132165
%arg : memref<?x1xi8, strided<[?, ?], offset: ?>>) -> vector<[16]x1xi8> {
133166
%c0 = arith.constant 0 : index

0 commit comments

Comments
 (0)