Skip to content

Commit 7e7ea9c

Browse files
authored
[MLIR] Extend vector.scatter to accept tensor as base (#165548)
This PR makes the following improvements to `vector.scatter` and its lowering pipeline: - In addition to `memref`, accept a ranked `tensor` as the base operand of `vector.scatter`, similar to `vector.transfer_write`. - Implement bufferization support for `vector.scatter`, so that tensor-based scatter ops can be fully lowered to memref-based forms. This is needed to complete the functionality of the map_scatter decomposition. The full discussion can be found here: iree-org/iree#21135 --------- Signed-off-by: Ryutaro Okada <[email protected]>
1 parent a407d02 commit 7e7ea9c

File tree

8 files changed

+122
-38
lines changed

8 files changed

+122
-38
lines changed

mlir/include/mlir/Dialect/Vector/IR/VectorOps.td

Lines changed: 25 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -2160,25 +2160,25 @@ def Vector_GatherOp :
21602160
];
21612161
}
21622162

2163-
def Vector_ScatterOp :
2164-
Vector_Op<"scatter", [
2165-
DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2166-
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>
2167-
]>,
2168-
Arguments<(ins Arg<AnyMemRef, "", [MemWrite]>:$base,
2169-
Variadic<Index>:$offsets,
2170-
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
2171-
VectorOfNonZeroRankOf<[I1]>:$mask,
2172-
AnyVectorOfNonZeroRank:$valueToStore,
2173-
OptionalAttr<IntValidAlignment<I64Attr>>: $alignment)> {
2163+
def Vector_ScatterOp
2164+
: Vector_Op<"scatter",
2165+
[DeclareOpInterfaceMethods<MemorySpaceCastConsumerOpInterface>,
2166+
DeclareOpInterfaceMethods<AlignmentAttrOpInterface>]>,
2167+
Arguments<(ins Arg<TensorOrMemRef<[AnyType]>, "", [MemWrite]>:$base,
2168+
Variadic<Index>:$offsets,
2169+
VectorOfNonZeroRankOf<[AnyInteger, Index]>:$indices,
2170+
VectorOfNonZeroRankOf<[I1]>:$mask,
2171+
AnyVectorOfNonZeroRank:$valueToStore,
2172+
OptionalAttr<IntValidAlignment<I64Attr>>:$alignment)>,
2173+
Results<(outs Optional<AnyRankedTensor>:$result)> {
21742174

21752175
let summary = [{
2176-
scatters elements from a vector into memory as defined by an index vector
2176+
scatters elements from a vector into memory or ranked tensor as defined by an index vector
21772177
and a mask vector
21782178
}];
21792179

21802180
let description = [{
2181-
The scatter operation stores elements from a n-D vector into memory as
2181+
The scatter operation stores elements from a n-D vector into memory or ranked tensor as
21822182
defined by a base with indices and an additional n-D index vector, but
21832183
only if the corresponding bit in a n-D mask vector is set. Otherwise, no
21842184
action is taken for that element. Informally the semantics are:
@@ -2221,31 +2221,28 @@ def Vector_ScatterOp :
22212221
}];
22222222

22232223
let extraClassDeclaration = [{
2224-
MemRefType getMemRefType() { return getBase().getType(); }
2224+
ShapedType getBaseType() { return getBase().getType(); }
22252225
VectorType getIndexVectorType() { return getIndices().getType(); }
22262226
VectorType getMaskVectorType() { return getMask().getType(); }
22272227
VectorType getVectorType() { return getValueToStore().getType(); }
22282228
}];
22292229

2230-
let assemblyFormat =
2231-
"$base `[` $offsets `]` `[` $indices `]` `,` "
2232-
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
2233-
"type($indices) `,` type($mask) `,` type($valueToStore)";
2230+
let assemblyFormat = "$base `[` $offsets `]` `[` $indices `]` `,` "
2231+
"$mask `,` $valueToStore attr-dict `:` type($base) `,` "
2232+
"type($indices) `,` type($mask) `,` "
2233+
"type($valueToStore) (`->` type($result)^)?";
22342234
let hasCanonicalizer = 1;
22352235
let hasVerifier = 1;
22362236

2237-
let builders = [
2238-
OpBuilder<(ins "Value":$base,
2239-
"ValueRange":$indices,
2240-
"Value":$index_vec,
2241-
"Value":$mask,
2242-
"Value":$valueToStore,
2243-
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">: $alignment), [{
2244-
return build($_builder, $_state, base, indices, index_vec, mask, valueToStore,
2237+
let builders = [OpBuilder<
2238+
(ins "Type":$resultType, "Value":$base, "ValueRange":$indices,
2239+
"Value":$index_vec, "Value":$mask, "Value":$valueToStore,
2240+
CArg<"llvm::MaybeAlign", "llvm::MaybeAlign()">:$alignment),
2241+
[{
2242+
return build($_builder, $_state, resultType, base, indices, index_vec, mask, valueToStore,
22452243
alignment.has_value() ? $_builder.getI64IntegerAttr(alignment->value()) :
22462244
nullptr);
2247-
}]>
2248-
];
2245+
}]>];
22492246
}
22502247

22512248
def Vector_ExpandLoadOp :

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,8 @@ class VectorScatterOpConversion
345345
matchAndRewrite(vector::ScatterOp scatter, OpAdaptor adaptor,
346346
ConversionPatternRewriter &rewriter) const override {
347347
auto loc = scatter->getLoc();
348-
MemRefType memRefType = scatter.getMemRefType();
348+
auto memRefType = dyn_cast<MemRefType>(scatter.getBaseType());
349+
assert(memRefType && "The base should be bufferized");
349350

350351
if (failed(isMemRefTypeSupported(memRefType, *this->getTypeConverter())))
351352
return rewriter.notifyMatchFailure(scatter, "memref type not supported");

mlir/lib/Dialect/SparseTensor/Transforms/SparseVectorization.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,8 +132,8 @@ static void genVectorStore(PatternRewriter &rewriter, Location loc, Value mem,
132132
SmallVector<Value> scalarArgs(idxs);
133133
Value indexVec = idxs.back();
134134
scalarArgs.back() = constantIndex(rewriter, loc, 0);
135-
vector::ScatterOp::create(rewriter, loc, mem, scalarArgs, indexVec, vmask,
136-
rhs);
135+
vector::ScatterOp::create(rewriter, loc, /*resultType=*/nullptr, mem,
136+
scalarArgs, indexVec, vmask, rhs);
137137
return;
138138
}
139139
vector::MaskedStoreOp::create(rewriter, loc, mem, idxs, vmask, rhs);

mlir/lib/Dialect/Vector/IR/VectorOps.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6066,19 +6066,21 @@ LogicalResult ScatterOp::verify() {
60666066
VectorType indVType = getIndexVectorType();
60676067
VectorType maskVType = getMaskVectorType();
60686068
VectorType valueVType = getVectorType();
6069-
MemRefType memType = getMemRefType();
6069+
ShapedType baseType = getBaseType();
60706070

6071-
if (valueVType.getElementType() != memType.getElementType())
6071+
if (!llvm::isa<MemRefType, RankedTensorType>(baseType))
6072+
return emitOpError("requires base to be a memref or ranked tensor type");
6073+
6074+
if (valueVType.getElementType() != baseType.getElementType())
60726075
return emitOpError("base and valueToStore element type should match");
6073-
if (llvm::size(getOffsets()) != memType.getRank())
6074-
return emitOpError("requires ") << memType.getRank() << " indices";
6076+
if (llvm::size(getOffsets()) != baseType.getRank())
6077+
return emitOpError("requires ") << baseType.getRank() << " indices";
60756078
if (valueVType.getShape() != indVType.getShape())
60766079
return emitOpError("expected valueToStore dim to match indices dim");
60776080
if (valueVType.getShape() != maskVType.getShape())
60786081
return emitOpError("expected valueToStore dim to match mask dim");
60796082
return success();
60806083
}
6081-
60826084
namespace {
60836085
class ScatterFolder final : public OpRewritePattern<ScatterOp> {
60846086
public:

mlir/lib/Dialect/Vector/Transforms/BufferizableOpInterfaceImpl.cpp

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
#include "mlir/Dialect/Vector/IR/VectorOps.h"
1515
#include "mlir/IR/Dialect.h"
1616
#include "mlir/IR/Operation.h"
17+
#include "mlir/IR/Value.h"
1718

1819
using namespace mlir;
1920
using namespace mlir::bufferization;
@@ -126,6 +127,54 @@ struct TransferWriteOpInterface
126127
}
127128
};
128129

130+
/// Bufferization of vector.scatter. Replaced with a new vector.scatter that
131+
/// operates on a memref.
132+
struct ScatterOpInterface
133+
: public BufferizableOpInterface::ExternalModel<ScatterOpInterface,
134+
vector::ScatterOp> {
135+
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
136+
const AnalysisState &state) const {
137+
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
138+
"only tensor types expected");
139+
return true;
140+
}
141+
142+
bool bufferizesToMemoryWrite(Operation *op, OpOperand &opOperand,
143+
const AnalysisState &state) const {
144+
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
145+
"only tensor types expected");
146+
return true;
147+
}
148+
149+
AliasingValueList getAliasingValues(Operation *op, OpOperand &opOperand,
150+
const AnalysisState &state) const {
151+
assert(isa<RankedTensorType>(opOperand.get().getType()) &&
152+
"only tensor types expected");
153+
auto scatterOp = cast<vector::ScatterOp>(op);
154+
if (&opOperand != &scatterOp.getBaseMutable())
155+
return {};
156+
return {{scatterOp.getResult(), BufferRelation::Equivalent}};
157+
}
158+
159+
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
160+
const BufferizationOptions &options,
161+
BufferizationState &state) const {
162+
auto scatterOp = cast<vector::ScatterOp>(op);
163+
assert(isa<TensorType>(scatterOp.getBaseType()) &&
164+
"only tensor types expected");
165+
FailureOr<Value> buffer =
166+
getBuffer(rewriter, scatterOp.getBase(), options, state);
167+
if (failed(buffer))
168+
return failure();
169+
vector::ScatterOp::create(rewriter, scatterOp.getLoc(),
170+
/*resultType=*/nullptr, *buffer,
171+
scatterOp.getOffsets(), scatterOp.getIndices(),
172+
scatterOp.getMask(), scatterOp.getValueToStore());
173+
replaceOpWithBufferizedValues(rewriter, op, *buffer);
174+
return success();
175+
}
176+
};
177+
129178
/// Bufferization of vector.gather. Replaced with a new vector.gather that
130179
/// operates on a memref.
131180
struct GatherOpInterface
@@ -335,5 +384,6 @@ void mlir::vector::registerBufferizableOpInterfaceExternalModels(
335384
GatherOp::attachInterface<GatherOpInterface>(*ctx);
336385
MaskOp::attachInterface<MaskOpInterface>(*ctx);
337386
YieldOp::attachInterface<YieldOpInterface>(*ctx);
387+
ScatterOp::attachInterface<ScatterOpInterface>(*ctx);
338388
});
339389
}

mlir/test/Dialect/Vector/bufferize.mlir

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,26 @@ func.func @transfer_write(%t: tensor<?x?xf32>, %o1: index,
3232

3333
// -----
3434

35+
// CHECK-LABEL: func @scatter(
36+
// CHECK-SAME: %[[base:.*]]: tensor<16x16xf32>, %[[v:.*]]: vector<16xi32>,
37+
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[value:.*]]: vector<16xf32>) -> tensor<16x16xf32>
38+
// CHECK: %[[buf:.*]] = bufferization.to_buffer %[[base]] : tensor<16x16xf32> to memref<16x16xf32>
39+
// CHECK: %[[c0:.*]] = arith.constant 0 : index
40+
// CHECK: %[[alloc:.*]] = memref.alloc() {alignment = 64 : i64} : memref<16x16xf32>
41+
// CHECK: memref.copy %[[buf]], %[[alloc]] : memref<16x16xf32> to memref<16x16xf32>
42+
// CHECK: vector.scatter %[[alloc]][%[[c0]], %[[c0]]] [%[[v]]], %[[mask]], %[[value]] : memref<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
43+
// CHECK: %[[tensor:.*]] = bufferization.to_tensor %[[alloc]] : memref<16x16xf32> to tensor<16x16xf32>
44+
// CHECK: return %[[tensor]] : tensor<16x16xf32>
45+
func.func @scatter(%base: tensor<16x16xf32>, %v: vector<16xi32>,
46+
%mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
47+
%c0 = arith.constant 0 : index
48+
%0 = vector.scatter %base[%c0, %c0][%v], %mask, %value
49+
: tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
50+
return %0 : tensor<16x16xf32>
51+
}
52+
53+
// -----
54+
3555
// CHECK-LABEL: func @gather(
3656
// CHECK-SAME: %[[base:.*]]: tensor<?x?xf32>, %[[v:.*]]: vector<16xi32>,
3757
// CHECK-SAME: %[[mask:.*]]: vector<16xi1>, %[[pass_thru:.*]]: vector<16xf32>)

mlir/test/Dialect/Vector/invalid.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1491,9 +1491,9 @@ func.func @gather_non_power_of_two_alignment(%base: memref<16xf32>, %indices: ve
14911491
func.func @scatter_to_vector(%base: vector<16xf32>, %indices: vector<16xi32>,
14921492
%mask: vector<16xi1>, %pass_thru: vector<16xf32>) {
14931493
%c0 = arith.constant 0 : index
1494-
// expected-error@+2 {{custom op 'vector.scatter' invalid kind of type specified}}
1494+
// expected-error@+1 {{'vector.scatter' op operand #0 must be Tensor or MemRef of any type values, but got 'vector<16xf32>'}}
14951495
vector.scatter %base[%c0][%indices], %mask, %pass_thru
1496-
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> into vector<16xf32>
1496+
: vector<16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32>
14971497
}
14981498

14991499
// -----

mlir/test/Dialect/Vector/ops.mlir

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1160,3 +1160,17 @@ func.func @step() {
11601160
%1 = vector.step : vector<[4]xindex>
11611161
return
11621162
}
1163+
1164+
// CHECK-LABEL: func @scatter_tensor(
1165+
// CHECK-SAME: %[[BASE:.*]]: tensor<16x16xf32>, %[[V:.*]]: vector<16xi32>,
1166+
// CHECK-SAME: %[[MASK:.*]]: vector<16xi1>, %[[VALUE:.*]]: vector<16xf32>) -> tensor<16x16xf32>
1167+
func.func @scatter_tensor(%base: tensor<16x16xf32>, %v: vector<16xi32>,
1168+
%mask: vector<16xi1>, %value: vector<16xf32>) -> tensor<16x16xf32> {
1169+
// CHECK: %[[C0:.*]] = arith.constant 0 : index
1170+
%c0 = arith.constant 0 : index
1171+
// CHECK: %[[RESULT:.*]] = vector.scatter %[[BASE]][%[[C0]], %[[C0]]] [%[[V]]], %[[MASK]], %[[VALUE]]
1172+
%0 = vector.scatter %base[%c0, %c0] [%v], %mask, %value
1173+
: tensor<16x16xf32>, vector<16xi32>, vector<16xi1>, vector<16xf32> -> tensor<16x16xf32>
1174+
// CHECK: return %[[RESULT]] : tensor<16x16xf32>
1175+
return %0 : tensor<16x16xf32>
1176+
}

0 commit comments

Comments
 (0)