Skip to content

Commit e1b6ddc

Browse files
MaheshRavishankarlialan
authored andcommitted
Signed-off-by: MaheshRavishankar <[email protected]>
1 parent ead218f commit e1b6ddc

File tree

5 files changed

+46
-31
lines changed

5 files changed

+46
-31
lines changed

compiler/src/iree/compiler/Codegen/Common/GPU/GPUFuseAndHoistParallelLoops.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -320,7 +320,8 @@ struct FuseTilableForallConsumers final
320320
}
321321

322322
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
323-
scf::tileAndFuseConsumerOfSlice(rewriter, producerSlice, {sliceOwner});
323+
scf::tileAndFuseConsumerOfSlices(rewriter, producerSlice.getOperation(),
324+
{sliceOwner});
324325
if (failed(fuseConsumerResults)) {
325326
return failure();
326327
}

compiler/src/iree/compiler/Codegen/Common/TileAndFuseUtils.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -162,8 +162,8 @@ fuseConsumersIntoLoops(RewriterBase &rewriter, Operation *tiledOp,
162162
candidates.pop();
163163

164164
FailureOr<scf::SCFFuseConsumerOfSliceResult> fusedResult =
165-
mlir::scf::tileAndFuseConsumerOfSlice(rewriter, candidateSliceOp,
166-
loops);
165+
mlir::scf::tileAndFuseConsumerOfSlices(
166+
rewriter, candidateSliceOp.getOperation(), loops);
167167
if (failed(fusedResult)) {
168168
LLVM_DEBUG(llvm::dbgs() << "failed to fuse consumer of slice: "
169169
<< candidateSliceOp << "\n");
@@ -190,14 +190,15 @@ fuseConsumersIntoLoops(RewriterBase &rewriter, Operation *tiledOp,
190190
}
191191

192192
// Replace the original consumer operation with the tiled implementation.
193-
rewriter.replaceOp(fusedResult->origConsumerOperand->getOwner(),
193+
rewriter.replaceOp(fusedResult->origConsumerOperands.front()->getOwner(),
194194
fusedResult->tiledOps.front());
195195

196196
// The result of the fused consumers might themselves be slices of
197197
// values produced by operations that implement the `TilingInterface`.
198198
// Add these operations to the worklist.
199-
addCandidateSlices(fusedResult->tiledAndFusedConsumerOperand->getOwner(),
200-
candidates);
199+
addCandidateSlices(
200+
fusedResult->tiledAndFusedConsumerOperands.front()->getOwner(),
201+
candidates);
201202

202203
// Add the list of new producer fusion opportunities.
203204
for (auto tiledOp : fusedResult.value().tiledOps) {

compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1217,16 +1217,16 @@ applyFuseConsumer(RewriterBase &rewriter, Operation *transformOp,
12171217
rewriter.setInsertionPoint(target);
12181218

12191219
FailureOr<scf::SCFFuseConsumerOfSliceResult> fuseConsumerResults =
1220-
scf::tileAndFuseConsumerOfSlice(rewriter, target, loops);
1220+
scf::tileAndFuseConsumerOfSlices(rewriter, target, loops);
12211221

12221222
if (failed(fuseConsumerResults))
12231223
return failure();
12241224

12251225
// Report back the relevant handles to the transform op.
12261226
originalConsumerOps.push_back(
1227-
fuseConsumerResults->origConsumerOperand->getOwner());
1227+
fuseConsumerResults->origConsumerOperands.front()->getOwner());
12281228
fusedConsumerOps.push_back(
1229-
fuseConsumerResults->tiledAndFusedConsumerOperand->getOwner());
1229+
fuseConsumerResults->tiledAndFusedConsumerOperands.front()->getOwner());
12301230
}
12311231

12321232
transformResults.set(transformOp->getOpResult(0), originalConsumerOps);

compiler/src/iree/compiler/Dialect/LinalgExt/IR/LinalgExtOps.td

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def IREELinalgExt_ScatterOp : IREELinalgExt_Op<"scatter",
103103
"getLoopIteratorTypes",
104104
"getResultTilePosition",
105105
"getTiledImplementation",
106-
"getIterationDomainTileFromOperandTile",
107-
"getTiledImplementationFromOperandTile"]>]> {
106+
"getIterationDomainTileFromOperandTiles",
107+
"getTiledImplementationFromOperandTiles"]>]> {
108108
let summary = [{Scatters an input in slices based on a tensor of indices.}];
109109
let description = [{
110110
Takes two `inputs` (`update` and `indices`) and `outputs` value (`original`).
@@ -326,8 +326,8 @@ def IREELinalgExt_MapScatterOp : IREELinalgExt_PureOp<"map_scatter",
326326
"getLoopIteratorTypes",
327327
"getResultTilePosition",
328328
"getTiledImplementation",
329-
"getIterationDomainTileFromOperandTile",
330-
"getTiledImplementationFromOperandTile",
329+
"getIterationDomainTileFromOperandTiles",
330+
"getTiledImplementationFromOperandTiles",
331331
"generateScalarImplementation"]>,
332332
SingleBlockImplicitTerminator<"::mlir::iree_compiler::IREE::LinalgExt::YieldOp">]> {
333333
let summary = [{Scatter with a mapping from source indices to result indices.}];

compiler/src/iree/compiler/Dialect/LinalgExt/IR/TilingInterfaceImpl.cpp

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -187,9 +187,10 @@ LogicalResult ScatterOp::getResultTilePosition(
187187

188188
/// Method to return the position of the result tile computed by the tiled
189189
/// operation.
190-
LogicalResult ScatterOp::getIterationDomainTileFromOperandTile(
191-
OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
192-
ArrayRef<OpFoldResult> sizes,
190+
LogicalResult ScatterOp::getIterationDomainTileFromOperandTiles(
191+
OpBuilder &b, ArrayRef<unsigned> operandNumbers,
192+
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
193+
ArrayRef<SmallVector<OpFoldResult>> allSizes,
193194
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
194195
SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
195196
// Fusion with producers is not possible in general if `unique_indices` is not
@@ -199,9 +200,12 @@ LogicalResult ScatterOp::getIterationDomainTileFromOperandTile(
199200
}
200201
// TODO: Support fusion along the index operand. For the index operand, the
201202
// offset + size must be the full size for the inner most dim.
202-
if (getInputs().getBeginOperandIndex() != operandNumber) {
203+
if (operandNumbers.size() != 1 ||
204+
getInputs().getBeginOperandIndex() != operandNumbers.front()) {
203205
return failure();
204206
}
207+
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
208+
ArrayRef<OpFoldResult> sizes(allSizes[0]);
205209

206210
// The iteration domain is defined in terms of the |input|, so simply
207211
// use the given offsets/sizes.
@@ -212,12 +216,14 @@ LogicalResult ScatterOp::getIterationDomainTileFromOperandTile(
212216

213217
/// Method to generate the tiled implementation of an operation from the tile
214218
/// of the operand.
215-
FailureOr<TilingResult> ScatterOp::getTiledImplementationFromOperandTile(
216-
OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
217-
ArrayRef<OpFoldResult> sizes) {
219+
FailureOr<TilingResult> ScatterOp::getTiledImplementationFromOperandTiles(
220+
OpBuilder &b, ArrayRef<unsigned> operandNumbers,
221+
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
222+
ArrayRef<SmallVector<OpFoldResult>> allSizes) {
218223
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
219-
if (failed(getIterationDomainTileFromOperandTile(
220-
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
224+
if (failed(getIterationDomainTileFromOperandTiles(
225+
b, operandNumbers, allOffsets, allSizes, mappedOffsets,
226+
mappedSizes))) {
221227
return failure();
222228
}
223229
return getTiledImplementation(b, mappedOffsets, mappedSizes);
@@ -500,27 +506,34 @@ LogicalResult MapScatterOp::getResultTilePosition(
500506
return success();
501507
}
502508

503-
LogicalResult MapScatterOp::getIterationDomainTileFromOperandTile(
504-
OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
505-
ArrayRef<OpFoldResult> sizes,
509+
LogicalResult MapScatterOp::getIterationDomainTileFromOperandTiles(
510+
OpBuilder &b, ArrayRef<unsigned> operandNumbers,
511+
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
512+
ArrayRef<SmallVector<OpFoldResult>> allSizes,
506513
SmallVectorImpl<OpFoldResult> &iterDomainOffsets,
507514
SmallVectorImpl<OpFoldResult> &iterDomainSizes) {
508-
if (operandNumber != getInputMutable().getOperandNumber()) {
515+
if (operandNumbers.size() != 1 ||
516+
operandNumbers.front() != getInputMutable().getOperandNumber()) {
509517
return failure();
510518
}
519+
ArrayRef<OpFoldResult> offsets(allOffsets[0]);
520+
ArrayRef<OpFoldResult> sizes(allSizes[0]);
521+
511522
// The iteration domain is defined in terms of the `input`, so simply
512523
// use the given offsets/sizes.
513524
iterDomainOffsets.assign(offsets.begin(), offsets.end());
514525
iterDomainSizes.assign(sizes.begin(), sizes.end());
515526
return success();
516527
}
517528

518-
FailureOr<TilingResult> MapScatterOp::getTiledImplementationFromOperandTile(
519-
OpBuilder &b, unsigned operandNumber, ArrayRef<OpFoldResult> offsets,
520-
ArrayRef<OpFoldResult> sizes) {
529+
FailureOr<TilingResult> MapScatterOp::getTiledImplementationFromOperandTiles(
530+
OpBuilder &b, ArrayRef<unsigned> operandNumbers,
531+
ArrayRef<SmallVector<OpFoldResult>> allOffsets,
532+
ArrayRef<SmallVector<OpFoldResult>> allSizes) {
521533
SmallVector<OpFoldResult> mappedOffsets, mappedSizes;
522-
if (failed(getIterationDomainTileFromOperandTile(
523-
b, operandNumber, offsets, sizes, mappedOffsets, mappedSizes))) {
534+
if (failed(getIterationDomainTileFromOperandTiles(
535+
b, operandNumbers, allOffsets, allSizes, mappedOffsets,
536+
mappedSizes))) {
524537
return failure();
525538
}
526539
return getTiledImplementation(b, mappedOffsets, mappedSizes);

0 commit comments

Comments
 (0)