@@ -187,9 +187,10 @@ LogicalResult ScatterOp::getResultTilePosition(
187187
188188// / Method to return the position of the result tile computed by the tiled
189189// / operation.
190- LogicalResult ScatterOp::getIterationDomainTileFromOperandTile (
191- OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
192- ArrayRef<OpFoldResult> sizes,
190+ LogicalResult ScatterOp::getIterationDomainTileFromOperandTiles (
191+ OpBuilder &b, ArrayRef<unsigned > operandNumbers,
192+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
193+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
193194 SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
194195 SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
195196 // Fusion with producers is not possible in general if `unique_indices` is not
@@ -199,9 +200,12 @@ LogicalResult ScatterOp::getIterationDomainTileFromOperandTile(
199200 }
200201 // TODO: Support fusion along the index operand. For the index operand, the
201202 // offset + size must be the full size for the inner most dim.
202- if (getInputs ().getBeginOperandIndex () != operandNumber) {
203+ if (operandNumbers.size () != 1 ||
204+ getInputs ().getBeginOperandIndex () != operandNumbers.front ()) {
203205 return failure ();
204206 }
207+ ArrayRef<OpFoldResult> offsets (allOffsets[0 ]);
208+ ArrayRef<OpFoldResult> sizes (allSizes[0 ]);
205209
206210 // The iteration domain is defined in terms of the |input|, so simply
207211 // use the given offsets/sizes.
@@ -212,12 +216,14 @@ LogicalResult ScatterOp::getIterationDomainTileFromOperandTile(
212216
213217// / Method to generate the tiled implementation of an operation from the tile
214218// / of the operand.
215- FailureOr<TilingResult> ScatterOp::getTiledImplementationFromOperandTile (
216- OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
217- ArrayRef<OpFoldResult> sizes) {
219+ FailureOr<TilingResult> ScatterOp::getTiledImplementationFromOperandTiles (
220+ OpBuilder &b, ArrayRef<unsigned > operandNumbers,
221+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
222+ ArrayRef<SmallVector<OpFoldResult>> allSizes) {
218223 SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219- if (failed (getIterationDomainTileFromOperandTile (
220- b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
224+ if (failed (getIterationDomainTileFromOperandTiles (
225+ b, operandNumbers, allOffsets, allSizes, mappedOffsets,
226+ mappedSizes))) {
221227 return failure ();
222228 }
223229 return getTiledImplementation (b, mappedOffsets, mappedSizes);
@@ -500,27 +506,34 @@ LogicalResult MapScatterOp::getResultTilePosition(
500506 return success ();
501507}
502508
503- LogicalResult MapScatterOp::getIterationDomainTileFromOperandTile (
504- OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
505- ArrayRef<OpFoldResult> sizes,
509+ LogicalResult MapScatterOp::getIterationDomainTileFromOperandTiles (
510+ OpBuilder &b, ArrayRef<unsigned > operandNumbers,
511+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
512+ ArrayRef<SmallVector<OpFoldResult>> allSizes,
506513 SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
507514 SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
508- if (operandNumber != getInputMutable ().getOperandNumber ()) {
515+ if (operandNumbers.size () != 1 ||
516+ operandNumbers.front () != getInputMutable ().getOperandNumber ()) {
509517 return failure ();
510518 }
519+ ArrayRef<OpFoldResult> offsets (allOffsets[0 ]);
520+ ArrayRef<OpFoldResult> sizes (allSizes[0 ]);
521+
511522 // The iteration domain is defined in terms of the `input`, so simply
512523 // use the given offsets/sizes.
513524 iterDomainOffsets.assign (offsets.begin (), offsets.end ());
514525 iterDomainSizes.assign (sizes.begin (), sizes.end ());
515526 return success ();
516527}
517528
518- FailureOr<TilingResult> MapScatterOp::getTiledImplementationFromOperandTile (
519- OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
520- ArrayRef<OpFoldResult> sizes) {
529+ FailureOr<TilingResult> MapScatterOp::getTiledImplementationFromOperandTiles (
530+ OpBuilder &b, ArrayRef<unsigned > operandNumbers,
531+ ArrayRef<SmallVector<OpFoldResult>> allOffsets,
532+ ArrayRef<SmallVector<OpFoldResult>> allSizes) {
521533 SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
522- if (failed (getIterationDomainTileFromOperandTile (
523- b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
534+ if (failed (getIterationDomainTileFromOperandTiles (
535+ b, operandNumbers, allOffsets, allSizes, mappedOffsets,
536+ mappedSizes))) {
524537 return failure ();
525538 }
526539 return getTiledImplementation (b, mappedOffsets, mappedSizes);
0 commit comments