Skip to content

Commit 6900621

Browse files
authored
[LinalgExt] Add gather operation (1/5) (iree-org#20460)
Adds `iree_linalg_ext.gather` operation which is the converse of `iree_linalg_ext.scatter`. Both operations share similar semantics, but while scatter writes values into `original`, gather reads values from `source` based on `indices`. --------- Signed-off-by: Ian Wood <[email protected]>
1 parent 7724306 commit 6900621

File tree

4 files changed

+318
-31
lines changed

4 files changed

+318
-31
lines changed

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.cpp

Lines changed: 64 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -133,29 +133,35 @@ static bool isSmallerThan(ArrayRef<int64_t> sourceShape,
133133
});
134134
}
135135

136-
//===----------------------------------------------------------------------===//
137-
// ScatterOp
138-
//===----------------------------------------------------------------------===//
139-
140-
LogicalResult ScatterOp::verify() {
141-
Operation *op = getOperation();
142-
if (getInputs().size() != 2) {
143-
return op->emitOpError("expected two input operands");
136+
/// Helper function to verify both `scatter` and `gather`. Since both ops share
137+
/// the same semantics, we can use the same function to verify them. Note: this
138+
/// is written from the perspective of `scatter` op. For gather, `updateType`
139+
/// maps to the type of the output and `originalType` maps to the type of the
140+
/// `source`.
141+
template <typename OpTy>
142+
static LogicalResult
143+
verifyGatherScatter(OpTy op, int64_t sliceRank, ShapedType originalType,
144+
ShapedType updateType, StringRef originalName,
145+
StringRef updateName) {
146+
static_assert(llvm::is_one_of<OpTy, GatherOp, ScatterOp>::value,
147+
"applies to only gather or scatter operations");
148+
if (op.getInputs().size() != 2) {
149+
return op.emitOpError("expected two input operands");
144150
}
145-
if (getOutputs().size() != 1) {
146-
return op->emitOpError("expected one output operand");
151+
if (op.getOutputs().size() != 1) {
152+
return op.emitOpError("expected one output operand");
147153
}
148154

149-
auto indicesType = getIndicesType();
155+
auto indicesType = op.getIndicesType();
150156
if (indicesType.getRank() < 1 ||
151157
!isa<IntegerType>(indicesType.getElementType())) {
152158
return op->emitOpError("expected indices to be of rank 1 or greater and of "
153159
"integer element type");
154160
}
155161

156-
ArrayRef<int64_t> dimMap = getDimensionMap();
162+
ArrayRef<int64_t> dimMap = op.getDimensionMap();
157163
if (failed(isPermSequence(
158-
[&]() { return this->emitOpError("dimension map is invalid."); },
164+
[&]() { return op->emitOpError("dimension map is invalid."); },
159165
dimMap))) {
160166
return failure();
161167
}
@@ -164,23 +170,24 @@ LogicalResult ScatterOp::verify() {
164170
return op->emitOpError("dimension map must have at least one element");
165171
}
166172

167-
const size_t indexDepth = getIndexDepth();
168-
auto originalType = getOriginalType();
169-
auto updateType = getUpdateType();
173+
const size_t indexDepth = op.getIndexDepth();
170174
const auto originalSliceRank = originalType.getRank() - indexDepth;
171175
if (originalSliceRank < 0) {
172-
return op->emitOpError(
173-
"expected original rank to be greater or equal to index depth");
176+
return op->emitOpError("expected " + originalName +
177+
" rank to be greater or equal to index depth");
174178
}
175179
if (updateType.getRank() < originalSliceRank) {
176-
return op->emitOpError(
177-
"expected update to be at least the rank of non indexed original dims");
180+
return op->emitOpError("expected " + updateName +
181+
" to be at least the rank of non indexed " +
182+
originalName + " dims");
178183
}
179184
const size_t batchRank = updateType.getRank() - originalSliceRank;
180185

181186
if (updateType.getRank() - batchRank != originalSliceRank) {
182-
return op->emitOpError("expected rank of update value - batch rank to be "
183-
"equal to rank of original value - index depth");
187+
return op->emitOpError("expected rank of " + updateName +
188+
" value - batch rank to be "
189+
"equal to rank of " +
190+
originalName + " value - index depth");
184191
}
185192

186193
if ((indicesType.getRank() != batchRank || indexDepth != 1) &&
@@ -196,8 +203,8 @@ LogicalResult ScatterOp::verify() {
196203
llvm::mismatch(indicesType.getShape().take_front(batchRank),
197204
updateType.getShape().take_front(batchRank));
198205
if (indicesIt != indicesType.getShape().take_front(batchRank).end()) {
199-
return op->emitOpError(
200-
"mismatch in shape of indices and update value at dim#")
206+
return op->emitOpError("mismatch in shape of indices and " + updateName +
207+
" value at dim#")
201208
<< (indicesIt - indicesType.getShape().begin());
202209
}
203210
}
@@ -208,7 +215,7 @@ LogicalResult ScatterOp::verify() {
208215
}
209216

210217
{
211-
for (auto idx : llvm::seq<int64_t>(0, getUpdateSliceRank())) {
218+
for (auto idx : llvm::seq<int64_t>(0, sliceRank)) {
212219
int64_t updateDim = idx + batchRank;
213220
int64_t origDim = idx + indexDepth;
214221
if (originalType.isDynamicDim(origDim) ||
@@ -217,14 +224,14 @@ LogicalResult ScatterOp::verify() {
217224
}
218225
if (originalType.getDimSize(origDim) !=
219226
updateType.getDimSize(updateDim)) {
220-
return op->emitOpError("shape of update value dim#")
221-
<< (updateDim) << " must match original value at dim#"
222-
<< (origDim);
227+
return op->emitOpError("shape of " + updateName + " value dim#")
228+
<< (updateDim)
229+
<< " must match " + originalName + " value at dim#" << (origDim);
223230
}
224231
}
225232
}
226233

227-
Region &region = this->getRegion();
234+
Region &region = op.getRegion();
228235
Block *body = &region.front();
229236
if (body->getNumArguments() != 2) {
230237
return op->emitOpError("expected region to have two arguments");
@@ -238,12 +245,12 @@ LogicalResult ScatterOp::verify() {
238245
}
239246
if (arg0Type != updateType.getElementType()) {
240247
return op->emitOpError("mismatch in argument 0 of region ")
241-
<< arg0Type << " and element type of update value "
248+
<< arg0Type << " and element type of " + updateName + " value "
242249
<< updateType.getElementType();
243250
}
244251
if (arg1Type != originalType.getElementType()) {
245252
return op->emitOpError("mismatch in argument 1 of region ")
246-
<< arg1Type << " and element type of original value "
253+
<< arg1Type << " and element type of " + originalName + " value "
247254
<< originalType.getElementType();
248255
}
249256
if (arg0Type != arg1Type) {
@@ -262,6 +269,15 @@ LogicalResult ScatterOp::verify() {
262269
return success();
263270
}
264271

272+
//===----------------------------------------------------------------------===//
273+
// ScatterOp
274+
//===----------------------------------------------------------------------===//
275+
276+
LogicalResult ScatterOp::verify() {
277+
return verifyGatherScatter(*this, getUpdateSliceRank(), getOriginalType(),
278+
getUpdateType(), "original", "update");
279+
}
280+
265281
LogicalResult
266282
ScatterOp::reifyResultShapes(OpBuilder &b,
267283
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
@@ -285,6 +301,22 @@ SmallVector<AffineMap> ScatterOp::getIndexingMapsForResults() {
285301
return {AffineMap(nullptr)};
286302
}
287303

304+
//===----------------------------------------------------------------------===//
305+
// GatherOp
306+
//===----------------------------------------------------------------------===//
307+
308+
LogicalResult GatherOp::verify() {
309+
return verifyGatherScatter(*this, getOutputSliceRank(), getSourceType(),
310+
getOutputType(), "source", "output");
311+
}
312+
313+
LogicalResult
314+
GatherOp::reifyResultShapes(OpBuilder &b,
315+
ReifiedRankedShapedTypeDims &reifiedReturnShapes) {
316+
return cast<LinalgExtOp>(getOperation())
317+
.reifyResultShapes(b, reifiedReturnShapes);
318+
}
319+
288320
//===----------------------------------------------------------------------===//
289321
// SortOp
290322
//===----------------------------------------------------------------------===//
@@ -1950,6 +1982,7 @@ LogicalResult IREE::LinalgExt::IndexOp::verify() {
19501982
}
19511983

19521984
DEFINE_OP_GET_EFFECTS(ScatterOp)
1985+
DEFINE_OP_GET_EFFECTS(GatherOp)
19531986
DEFINE_OP_GET_EFFECTS(SortOp)
19541987
DEFINE_OP_GET_EFFECTS(FftOp)
19551988
DEFINE_OP_GET_EFFECTS(ScanOp)

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -223,6 +223,95 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
223223
}];
224224
}
225225

226+
def IREELinalgExt_GatherOp : IREELinalgExt_Op<"gather",
227+
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>]> {
228+
let summary = "Gather operator";
229+
let description = [{
230+
Takes two inputs (`source` and `indices`) and outputs value (`output`).
231+
The operation returns the value at the slices specified by `indices` by
232+
combining the gathered value from `source` with the value in `output`
233+
using the computation specified in `region`. The `region` specifies a binary
234+
operation of signature `(T, T) -> T`, where `T` is the element-type of
235+
`source` & `output`. The first argument is from `source` and the second is
236+
from `output`.
237+
238+
The size of the `dimension_map` attribute is used to determine how many
239+
indices are used to index into `source`, i.e. `index_depth`. The
240+
`dimension_map` attribute describes which index value maps to which dimension
241+
in `source`.
242+
243+
This operation performs the opposite operation of `iree_linalg_ext.scatter`.
244+
Instead of scattering `updates` into `original`, it gathers the values from
245+
`source` into `output` using the indices in `indices`. See the documentation
246+
on `iree_linalg_ext.scatter` for more details regarding the indexing/shape
247+
semantics.
248+
}];
249+
let arguments = (ins
250+
Variadic<AnyRankedTensorOrMemRefType>:$inputs,
251+
Variadic<AnyRankedTensorOrMemRefType>:$outputs,
252+
DenseI64ArrayAttr:$dimension_map
253+
);
254+
let results = (outs Variadic<AnyRankedTensor>:$results);
255+
let regions = (region AnyRegion:$region);
256+
let assemblyFormat = [{
257+
attr-dict `dimension_map` `=` $dimension_map
258+
(`ins` `(` $inputs^ `:` type($inputs) `)`)?
259+
`outs` `(` $outputs `:` type($outputs) `)`
260+
$region (`->` type($results)^)?
261+
}];
262+
263+
let extraClassDeclaration = extraLinalgExtOpClassDeclaration # [{
264+
static constexpr unsigned kSourceOpNum = 0;
265+
static constexpr unsigned kIndicesOpNum = 1;
266+
static constexpr unsigned kResultOpNum = 2;
267+
268+
/// Utility to get the number of indices used to index into `source`.
269+
int64_t getIndexDepth() {
270+
return getDimensionMap().size();
271+
}
272+
273+
/// Utility to get rank of the portion of `output` that is contiguous.
274+
int64_t getOutputSliceRank() {
275+
return getSourceType().getRank() - getIndexDepth();
276+
}
277+
278+
/// Utility to get the rank of the portion of `indices` that represents the
279+
/// batch dimensions.
280+
int64_t getBatchRank() {
281+
return getOutputType().getRank() - getOutputSliceRank();
282+
}
283+
284+
Value getSource(){
285+
return getOperand(kSourceOpNum);
286+
}
287+
288+
ShapedType getSourceType(){
289+
return cast<ShapedType>(getSource().getType());
290+
}
291+
292+
Value getIndices(){
293+
return getOperand(kIndicesOpNum);
294+
}
295+
296+
ShapedType getIndicesType(){
297+
return cast<ShapedType>(getIndices().getType());
298+
}
299+
300+
Value getOutput(){
301+
return getOperand(kResultOpNum);
302+
}
303+
304+
ShapedType getOutputType(){
305+
return cast<ShapedType>(getOutput().getType());
306+
}
307+
308+
/// For DPS interface.
309+
MutableOperandRange getDpsInitsMutable() {
310+
return getOutputsMutable();
311+
}
312+
}];
313+
}
314+
226315
def IREELinalgExt_SortOp : IREELinalgExt_Op<"sort",
227316
[DeclareOpInterfaceMethods<ReifyRankedShapedTypeOpInterface>,
228317
DeclareOpInterfaceMethods<TilingInterface,

compiler/src/iree/compiler/Dialect/LinalgExt/IR/test/invalid.mlir

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -421,6 +421,70 @@ func.func @scatter_index_depth_too_small(
421421

422422
// -----
423423

424+
func.func @gather_output_too_large(
425+
%source : tensor<10xf32>, %idx : tensor<1xi32>,
426+
%output : tensor<2xf32>) -> tensor<2xf32> {
427+
// expected-error @below {{'iree_linalg_ext.gather' op mismatch in shape of indices and output value at dim#0}}
428+
%0 = iree_linalg_ext.gather
429+
dimension_map = [0]
430+
ins(%source, %idx : tensor<10xf32>, tensor<1xi32>)
431+
outs(%output : tensor<2xf32>) {
432+
^bb0(%arg0: f32, %arg1: f32):
433+
iree_linalg_ext.yield %arg0 : f32
434+
} -> tensor<2xf32>
435+
return %0 : tensor<2xf32>
436+
}
437+
438+
// -----
439+
440+
func.func @gather_mismatch_output_and_source(
441+
%source : tensor<10x10xf32>, %idx : tensor<2xi32>,
442+
%output : tensor<1xf32>) -> tensor<1xf32> {
443+
// expected-error @below {{'iree_linalg_ext.gather' op shape of output value dim#0 must match source value at dim#1}}
444+
%0 = iree_linalg_ext.gather
445+
dimension_map = [0]
446+
ins(%source, %idx : tensor<10x10xf32>, tensor<2xi32>)
447+
outs(%output : tensor<1xf32>) {
448+
^bb0(%arg0: f32, %arg1: f32):
449+
iree_linalg_ext.yield %arg0 : f32
450+
} -> tensor<1xf32>
451+
return %0 : tensor<1xf32>
452+
}
453+
454+
// -----
455+
456+
func.func @gather_indices_batch_rank_too_large(
457+
%source : tensor<10x10xf32>, %idx : tensor<1x2xi32>,
458+
%output : tensor<10xf32>) -> tensor<10xf32> {
459+
// expected-error @below {{'iree_linalg_ext.gather' op expected indices to be equal to batch rank or batch rank + 1}}
460+
%0 = iree_linalg_ext.gather
461+
dimension_map = [0]
462+
ins(%source, %idx : tensor<10x10xf32>, tensor<1x2xi32>)
463+
outs(%output : tensor<10xf32>) {
464+
^bb0(%arg0: f32, %arg1: f32):
465+
iree_linalg_ext.yield %arg0 : f32
466+
} -> tensor<10xf32>
467+
return %0 : tensor<10xf32>
468+
}
469+
470+
// -----
471+
472+
func.func @gather_dim_map_mismatch(
473+
%source : tensor<2xf32>, %idx : tensor<1xi32>,
474+
%output : tensor<1xf32>) -> tensor<1xf32> {
475+
// expected-error @below {{'iree_linalg_ext.gather' op expected output to be at least the rank of non indexed source dims}}
476+
%0 = iree_linalg_ext.gather
477+
dimension_map = [0, 1]
478+
ins(%source, %idx : tensor<2xf32>, tensor<1xi32>)
479+
outs(%output : tensor<1xf32>) {
480+
^bb0(%arg0: f32, %arg1: f32):
481+
iree_linalg_ext.yield %arg0 : f32
482+
} -> tensor<1xf32>
483+
return %0 : tensor<1xf32>
484+
}
485+
486+
// -----
487+
424488
func.func @topk_invalid(%input_values: tensor<2x10xf32>, %input_indices: tensor<2x10xi32>, %out_values : tensor<2x3xf32>, %out_indices: tensor<2x3xi32>) -> (tensor<2x3xf32>, tensor<2x3xi32>) {
425489
// expected-error@+1 {{expected one or two input operands}}
426490
%0:2 = iree_linalg_ext.topk

0 commit comments

Comments
 (0)