Skip to content

Commit fa399f7

Browse files
sakupan102 and github-actions[bot]
authored and committed
Automerge: [MLIR] [Vector] Fix canonicalization for vector.scatter with tensor output (#168824)
Commit llvm/llvm-project@7e7ea9c added tensor support for scatter, but running the existing canonicalization on tensors causes bugs, so we fix the canonicalization with tensor output. Closes llvm/llvm-project#168695 --------- Signed-off-by: Ryutaro Okada <[email protected]>
2 parents 94e2ac4 + 04b1975 commit fa399f7

File tree

2 files changed

+64
-1
lines changed

2 files changed

+64
-1
lines changed

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6100,11 +6100,22 @@ class ScatterFolder final : public OpRewritePattern<ScatterOp> {
61006100
using Base::Base;
61016101
LogicalResult matchAndRewrite(ScatterOp scatter,
61026102
PatternRewriter &rewriter) const override {
6103+
ShapedType baseType = scatter.getBaseType();
6104+
bool isMemRef = isa<MemRefType>(baseType);
6105+
if (!isMemRef && !isa<RankedTensorType>(baseType))
6106+
return failure();
6107+
6108+
// Memrefs have no result, so an all-false mask can simply erase the op.
6109+
// Tensors carry the updated value, so we must replace uses with the
6110+
// original base tensor instead of erasing.
61036111
switch (getMaskFormat(scatter.getMask())) {
61046112
case MaskFormat::AllTrue:
61056113
return failure(); // no unmasked equivalent
61066114
case MaskFormat::AllFalse:
6107-
rewriter.eraseOp(scatter);
6115+
if (isMemRef)
6116+
rewriter.eraseOp(scatter);
6117+
else
6118+
rewriter.replaceOp(scatter, scatter.getBase());
61086119
return success();
61096120
case MaskFormat::Unknown:
61106121
return failure();
@@ -6120,6 +6131,11 @@ class FoldContiguousScatter final : public OpRewritePattern<ScatterOp> {
61206131
using Base::Base;
61216132
LogicalResult matchAndRewrite(ScatterOp op,
61226133
PatternRewriter &rewriter) const override {
6134+
// Fold only for memrefs: the replacement uses maskedstore, which does not
6135+
// support tensor bases. Tensor cases intentionally bail out.
6136+
if (!isa<MemRefType>(op.getBase().getType()))
6137+
return failure();
6138+
61236139
if (failed(isZeroBasedContiguousSeq(op.getIndices())))
61246140
return failure();
61256141

mlir/test/Dialect/Vector/canonicalize.mlir

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3928,6 +3928,53 @@ func.func @contiguous_scatter_step(%base: memref<?xf32>,
39283928

39293929
// -----
39303930

3931+
// No canonicalization should happen here as the base is a tensor.
3932+
// CHECK-LABEL: @no_fold_contiguous_scatter_tensor
3933+
// CHECK-NOT: vector.maskedstore
3934+
// CHECK: %[[RES:.*]] = vector.scatter
3935+
// CHECK: return %[[RES]]
3936+
func.func @no_fold_contiguous_scatter_tensor(%base: tensor<16xf32>,
3937+
%mask: vector<16xi1>,
3938+
%value: vector<16xf32>) -> tensor<16xf32> {
3939+
%c0 = arith.constant 0 : index
3940+
%indices = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15]> : vector<16xi32>
3941+
%0 = vector.scatter %base[%c0] [%indices], %mask, %value
3942+
: tensor<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
3943+
return %0 : tensor<16xf32>
3944+
}
3945+
3946+
// -----
3947+
3948+
// CHECK-LABEL: @scatter_memref_all_false
3949+
// CHECK-SAME: (%[[BASE:.*]]: memref<?xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>)
3950+
// CHECK-NEXT: return
3951+
func.func @scatter_memref_all_false(%base: memref<?xf32>,
3952+
%index: vector<16xindex>,
3953+
%value: vector<16xf32>) {
3954+
%c0 = arith.constant 0 : index
3955+
%mask = arith.constant dense<false> : vector<16xi1>
3956+
vector.scatter %base[%c0][%index], %mask, %value
3957+
: memref<?xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32>
3958+
return
3959+
}
3960+
3961+
// -----
3962+
3963+
// CHECK-LABEL: @scatter_tensor_all_false
3964+
// CHECK-SAME: (%[[BASE:.*]]: tensor<16xf32>, %[[INDEX:.*]]: vector<16xindex>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16xf32> {
3965+
// CHECK: return %[[BASE]] : tensor<16xf32>
3966+
func.func @scatter_tensor_all_false(%base: tensor<16xf32>,
3967+
%index: vector<16xindex>,
3968+
%value: vector<16xf32>) -> tensor<16xf32> {
3969+
%c0 = arith.constant 0 : index
3970+
%mask = arith.constant dense<false> : vector<16xi1>
3971+
%0 = vector.scatter %base[%c0][%index], %mask, %value
3972+
: tensor<16xf32>, vector<16xindex>, vector<16xi1>, vector<16xf32> -> tensor<16xf32>
3973+
return %0 : tensor<16xf32>
3974+
}
3975+
3976+
// -----
3977+
39313978
// CHECK-LABEL: @fold_extract_constant_indices
39323979
// CHECK-SAME: %[[ARG:.*]]: vector<32x1xi32>) -> i32 {
39333980
// CHECK: %[[RES:.*]] = vector.extract %[[ARG]][0, 0] : i32 from vector<32x1xi32>

0 commit comments

Comments
 (0)