@@ -88,7 +88,7 @@ SmallVector<Range> ScatterOp::getIterationDomain(OpBuilder &builder) {
8888 SmallVector<Range> ranges;
8989 for (auto dim : llvm::seq<int64_t >(0 , getUpdateType ().getRank ())) {
9090 OpFoldResult ub = getDim (builder, loc, getUpdates (), dim);
91- ranges.emplace_back (Range{zero, ub, one});
91+ ranges.push_back (Range{zero, ub, one});
9292 }
9393 return ranges;
9494}
@@ -277,6 +277,117 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b,
277277 return success ();
278278}
279279
280+ // ===----------------------------------------------------------------------===//
281+ // GatherOp
282+ // ===----------------------------------------------------------------------===//
283+
284+ SmallVector<utils::IteratorType> GatherOp::getLoopIteratorTypes () {
285+ return SmallVector<utils::IteratorType>(getOutputType ().getRank (),
286+ utils::IteratorType::parallel);
287+ }
288+
289+ SmallVector<Range> GatherOp::getIterationDomain (OpBuilder &builder) {
290+ Location loc = getLoc ();
291+ OpFoldResult zero = builder.getIndexAttr (0 );
292+ OpFoldResult one = builder.getIndexAttr (1 );
293+ SmallVector<Range> ranges;
294+ for (auto dim : llvm::seq<int64_t >(0 , getOutputType ().getRank ())) {
295+ OpFoldResult ub = getDim (builder, loc, getOutput (), dim);
296+ ranges.push_back (Range{zero, ub, one});
297+ }
298+ return ranges;
299+ }
300+
301+ FailureOr<TilingResult>
302+ GatherOp::getTiledImplementation (OpBuilder &builder,
303+ ArrayRef<OpFoldResult> offsets,
304+ ArrayRef<OpFoldResult> sizes) {
305+ assert (offsets.size () >= 1 && sizes.size () >= 1 );
306+ Location loc = getLoc ();
307+ auto zeroAttr = builder.getI64IntegerAttr (0 );
308+ auto oneAttr = builder.getI64IntegerAttr (1 );
309+ SmallVector<Operation *> slices;
310+
311+ // Slice of the result.
312+ auto resultRank = getOutputType ().getRank ();
313+ SmallVector<OpFoldResult> resultStrides (resultRank, oneAttr);
314+ Operation *resultSlice =
315+ getSlice (builder, loc, getOutput (), offsets, sizes, resultStrides);
316+ if (!resultSlice) {
317+ return emitOpError (" failed to get result slice" );
318+ }
319+ Value tiledResult = resultSlice->getResult (0 );
320+
321+ // Slice of indices.
322+ auto indicesRank = getIndicesType ().getRank ();
323+ SmallVector<OpFoldResult> indicesOffsets (offsets.take_front (getBatchRank ()));
324+ SmallVector<OpFoldResult> indicesSizes (sizes.take_front (getBatchRank ()));
325+ if (getBatchRank () != getIndicesType ().getRank ()) {
326+ indicesOffsets.push_back (zeroAttr);
327+ indicesSizes.push_back (builder.getIndexAttr (getIndexDepth ()));
328+ }
329+ SmallVector<OpFoldResult> indicesStrides (indicesRank, oneAttr);
330+
331+ Operation *indicesSlice = getSlice (builder, loc, getIndices (), indicesOffsets,
332+ indicesSizes, indicesStrides);
333+ if (!indicesSlice) {
334+ return emitOpError (" failed to get indices slices" );
335+ }
336+ Value tiledIndices = indicesSlice->getResult (0 );
337+
338+ // Slice of the source.
339+ auto sourceRank = getSourceType ().getRank ();
340+ auto indexDepth = getIndexDepth ();
341+
342+ // The first `indexDepth` dims are not tiled
343+ SmallVector<OpFoldResult> sourceOffsets, sourceSizes;
344+ for (auto dim : llvm::seq<int64_t >(0 , indexDepth)) {
345+ sourceOffsets.push_back (zeroAttr);
346+ sourceSizes.push_back (getDim (builder, loc, getSource (), dim));
347+ }
348+ llvm::append_range (sourceOffsets,
349+ offsets.slice (getBatchRank (), sourceRank - indexDepth));
350+ llvm::append_range (sourceSizes,
351+ sizes.slice (getBatchRank (), sourceRank - indexDepth));
352+ SmallVector<OpFoldResult> sourceStrides (sourceRank, oneAttr);
353+ Operation *sourceSlice = getSlice (builder, loc, getSource (), sourceOffsets,
354+ sourceSizes, sourceStrides);
355+ if (!sourceSlice) {
356+ return emitOpError (" failed to get source tensor slice" );
357+ }
358+ Value tiledSource = sourceSlice->getResult (0 );
359+
360+ slices.push_back (sourceSlice);
361+ slices.push_back (indicesSlice);
362+ slices.push_back (resultSlice);
363+
364+ SmallVector<Type> resultTypes;
365+ if (getNumResults ()) {
366+ resultTypes.push_back (tiledResult.getType ());
367+ }
368+ Operation *tiledGatherOp =
369+ mlir::clone (builder, getOperation (), resultTypes,
370+ ValueRange{tiledSource, tiledIndices, tiledResult});
371+ return TilingResult{
372+ {tiledGatherOp}, SmallVector<Value>(tiledGatherOp->getResults ()), slices};
373+ }
374+
375+ LogicalResult GatherOp::getResultTilePosition (
376+ OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
377+ ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
378+ SmallVector<OpFoldResult> &resultSizes) {
379+ resultOffsets.assign (offsets.begin (), offsets.end ());
380+ resultSizes.assign (sizes.begin (), sizes.end ());
381+ return success ();
382+ }
383+
384+ FailureOr<TilingResult>
385+ GatherOp::generateResultTileValue (OpBuilder &builder, unsigned resultNumber,
386+ ArrayRef<OpFoldResult> offsets,
387+ ArrayRef<OpFoldResult> sizes) {
388+ return getTiledImplementation (builder, offsets, sizes);
389+ }
390+
280391// ===----------------------------------------------------------------------===//
281392// SortOp
282393// ===----------------------------------------------------------------------===//
0 commit comments