Skip to content

Commit 90ce1b2

Browse files
Merge branch 'main' into vinay-issue-115394
2 parents bfe6b86 + 91bbebc commit 90ce1b2

File tree

5 files changed

+225
-91
lines changed

5 files changed

+225
-91
lines changed

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
427427
/*defaultImplementation=*/[{
428428
return failure();
429429
}]
430+
>,
431+
InterfaceMethod<
432+
/*desc=*/[{
433+
Method to return the position of the partial result tile computed by
434+
the tiled operation. This is the same as
435+
TilingInterface::getResultTilePosition, but determines the result
436+
tile position for partial reduction.
437+
}],
438+
/*retType=*/"::llvm::LogicalResult",
439+
/*methodName=*/"getPartialResultTilePosition",
440+
/*args=*/(ins
441+
"::mlir::OpBuilder &":$b,
442+
"unsigned":$resultNumber,
443+
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
444+
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
445+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
446+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
447+
"::mlir::ArrayRef<int>":$reductionDims),
448+
/*methodBody=*/"",
449+
/*defaultImplementation=*/[{
450+
return failure();
451+
}]
430452
>
431453
];
432454
}

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
105105
static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106106
ArrayRef<MeshSharding> resultShardings,
107107
SymbolTableCollection &symbolTable) {
108-
for (const MeshSharding& sharding : operandShardings) {
108+
for (const MeshSharding &sharding : operandShardings) {
109109
if (sharding) {
110110
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111111
}
112112
}
113113

114-
for (const MeshSharding& sharding : resultShardings) {
114+
for (const MeshSharding &sharding : resultShardings) {
115115
if (sharding) {
116116
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117117
}
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
129129
// the original operand.
130130
// The other processes would use the reduction operation neutral tensor.
131131
static Value createDestinationPassingStyleInitOperand(
132-
LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
133-
MeshOp meshOp, ImplicitLocOpBuilder &builder) {
132+
LinalgOp op, int operandNumber, Value spmdizedOperand,
133+
ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134+
ImplicitLocOpBuilder &builder) {
134135
Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
135136
meshOp.getSymName(), reductionMeshAxes, builder);
136137
Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
152153
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
153154
SmallVector<OpFoldResult> shape =
154155
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
155-
PartialReductionOpInterface partialReductionIface =
156-
llvm::cast<PartialReductionOpInterface>(op.getOperation());
157-
assert(op->getNumResults() == 1 && "Multiple results not supported.");
158-
FailureOr<SmallVector<Value>> reductionNeutralTensor =
159-
partialReductionIface.generateInitialTensorForPartialReduction(
160-
builder, builder.getLoc(), shape, {});
161-
assert(succeeded(reductionNeutralTensor));
162-
builder.create<scf::YieldOp>(reductionNeutralTensor.value());
156+
157+
SmallVector<Operation *> combinerOps;
158+
matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159+
assert(combinerOps.size() == 1);
160+
std::optional<TypedAttr> neutralEl =
161+
arith::getNeutralElement(combinerOps[0]);
162+
163+
Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
164+
neutralEl.value().getType());
165+
Value constant =
166+
builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167+
Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
168+
.getResult(0);
169+
170+
builder.create<scf::YieldOp>(fill);
163171
}
164172
return ifOp.getResult(0);
165173
}
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
178186
Value spmdizedInitOperand =
179187
spmdizationMap.lookup(op->getOperands()[operandIdx]);
180188
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
181-
op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
189+
op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
182190
return newOperands;
183191
}
184192

mlir/lib/Dialect/Linalg/Transforms/TilingInterfaceImpl.cpp

Lines changed: 113 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,27 @@ struct LinalgOpTilingInterface
324324
// External Model for implementing `PartialReductionInterface` for `LinalgOp`s.
325325
//===----------------------------------------------------------------------===//
326326

327-
/// External model implementation of PartialReductionInterface for LinalgOps.
327+
/// Return an AffineMap for a partial result for the given result number,
328+
/// assuming the partial tiling strategy is outer-reduction loop +
329+
/// inner-parallel tile. The returned AffineMap can be used as the replacement
330+
/// AffineMap for the inner-parallel tile linalg op for the given result number.
331+
///
332+
/// The new AffineMap is the old AffineMap with reduction dimensions appended
333+
/// at end.
334+
static AffineMap getPartialResultAffineMap(LinalgOp linalgOp,
335+
ArrayRef<int> reductionDims,
336+
unsigned resultNumber) {
337+
AffineMap map =
338+
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(resultNumber));
339+
for (int redPos : reductionDims) {
340+
map = map.insertResult(getAffineDimExpr(redPos, linalgOp.getContext()),
341+
map.getNumResults());
342+
}
343+
return map;
344+
}
345+
346+
/// External model implementation of PartialReductionInterface for
347+
/// LinalgOps.
328348
template <typename LinalgOpTy>
329349
struct LinalgOpPartialReductionInterface
330350
: public PartialReductionOpInterface::ExternalModel<
@@ -338,11 +358,24 @@ struct LinalgOpPartialReductionInterface
338358
if (linalgOp.hasPureBufferSemantics())
339359
return op->emitOpError("expected operation to have tensor semantics");
340360

361+
// LinalgOp implements TilingInterface.
362+
auto tilingInterfaceOp = cast<TilingInterface>(linalgOp.getOperation());
363+
SmallVector<OpFoldResult> shape =
364+
llvm::map_to_vector(tilingInterfaceOp.getIterationDomain(b),
365+
[](Range x) { return x.size; });
366+
367+
SmallVector<OpFoldResult> tiledShape;
368+
for (auto [tileSize, dimSize] : llvm::zip_equal(sizes, shape)) {
369+
if (isZeroIndex(tileSize)) {
370+
tiledShape.push_back(dimSize);
371+
} else {
372+
tiledShape.push_back(tileSize);
373+
}
374+
}
375+
341376
SmallVector<Value> inits;
342377
for (int initIdx = 0, e = linalgOp.getNumDpsInits(); initIdx < e;
343378
++initIdx) {
344-
// Insert the new parallel dimension based on the index of the reduction
345-
// loops. This could be controlled by user for more flexibility.
346379
SmallVector<Operation *, 4> combinerOps;
347380
if (!matchReduction(linalgOp.getRegionOutputArgs(), initIdx,
348381
combinerOps) ||
@@ -355,33 +388,19 @@ struct LinalgOpPartialReductionInterface
355388
return op->emitOpError(
356389
"Failed to get an identity value for the reduction operation.");
357390

358-
ArrayRef<int64_t> oldShape =
359-
linalgOp.getShape(linalgOp.getDpsInitOperand(initIdx));
360-
361-
// Calculate the new shape, we insert the new dimensions based on the
362-
// index of the reduction dimensions.
363-
SmallVector<int64_t> newOutputShape;
364-
SmallVector<Value> dynamicDims;
365-
int64_t currReductionDims = 0;
366-
DenseSet<int> reductionDimsSet(reductionDims.begin(),
367-
reductionDims.end());
368-
for (int64_t idx :
369-
llvm::seq<int64_t>(0, oldShape.size() + reductionDims.size())) {
370-
if (reductionDimsSet.contains(idx)) {
371-
dispatchIndexOpFoldResults(sizes[idx], dynamicDims, newOutputShape);
372-
currReductionDims++;
373-
continue;
374-
}
375-
int64_t oldIdx = idx - currReductionDims;
376-
int64_t dim = oldShape[oldIdx];
377-
newOutputShape.push_back(dim);
378-
if (ShapedType::isDynamic(dim))
379-
dynamicDims.push_back(b.create<tensor::DimOp>(
380-
loc, linalgOp.getDpsInitOperand(initIdx)->get(), oldIdx));
391+
// Append the new partial result dimensions.
392+
AffineMap partialMap =
393+
getPartialResultAffineMap(linalgOp, reductionDims, initIdx);
394+
SmallVector<OpFoldResult> partialResultShape;
395+
for (AffineExpr dimExpr : partialMap.getResults()) {
396+
auto dim = cast<AffineDimExpr>(dimExpr);
397+
partialResultShape.push_back(tiledShape[dim.getPosition()]);
381398
}
382-
Value emptyTensor = b.create<tensor::EmptyOp>(
383-
loc, newOutputShape,
384-
linalgOp.getRegionOutputArgs()[initIdx].getType(), dynamicDims);
399+
400+
Type elType =
401+
getElementTypeOrSelf(linalgOp->getResult(initIdx).getType());
402+
Value emptyTensor =
403+
b.create<tensor::EmptyOp>(loc, partialResultShape, elType);
385404
Value constantOp = b.create<arith::ConstantOp>(loc, *identity);
386405
auto identityTensor =
387406
b.create<linalg::FillOp>(loc, constantOp, emptyTensor);
@@ -407,11 +426,7 @@ struct LinalgOpPartialReductionInterface
407426
// TODO: linalg::Generic doesn't have getDpsInitOperands. Can replace
408427
// this with a for range loop when we have it.
409428
AffineMap newMap =
410-
linalgOp.getMatchingIndexingMap(linalgOp.getDpsInitOperand(idx));
411-
for (int redPos : reductionDims) {
412-
newMap = newMap.insertResult(b.getAffineDimExpr(redPos),
413-
newMap.getNumResults());
414-
}
429+
getPartialResultAffineMap(linalgOp, reductionDims, idx);
415430
newInitMaps.push_back(newMap);
416431
}
417432

@@ -476,29 +491,75 @@ struct LinalgOpPartialReductionInterface
476491
Location loc, ValueRange partialReduce,
477492
ArrayRef<int> reductionDims) const {
478493
auto linalgOp = cast<LinalgOp>(op);
479-
SmallVector<int64_t> reductionDimsInt64(reductionDims);
480-
auto reduction = b.create<linalg::ReduceOp>(
481-
loc, partialReduce, linalgOp.getDpsInits(), reductionDimsInt64,
482-
[&linalgOp](OpBuilder &b, Location loc, ValueRange inputs) {
483-
int64_t numInits = linalgOp.getNumDpsInits();
484-
SmallVector<Value> yieldedValues;
485-
for (int idx : llvm::seq<int>(0, numInits)) {
494+
495+
// Permute the reduction dims as permuted by the partial result map.
496+
497+
int64_t numInits = linalgOp.getNumDpsInits();
498+
SmallVector<Operation *> mergeOperations;
499+
SmallVector<Value> replacements;
500+
for (int idx : llvm::seq(numInits)) {
501+
// linalg.reduce's iteration space is the tiled result's iteration space
502+
// (and not the tiled operation's iteration space). To account for this,
503+
// permute the reduction dimensions based on the partial result map of the
504+
// tiled result.
505+
AffineMap partialMap =
506+
getPartialResultAffineMap(linalgOp, reductionDims, idx);
507+
SmallVector<int64_t> partialReductionDims;
508+
for (auto [resultNum, dimExpr] :
509+
llvm::enumerate(partialMap.getResults())) {
510+
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
511+
if (llvm::find(reductionDims, dim) != reductionDims.end()) {
512+
partialReductionDims.push_back(resultNum);
513+
}
514+
}
515+
516+
Value partialResult = partialReduce[idx];
517+
Value init = linalgOp.getDpsInits()[idx];
518+
519+
auto reduction = b.create<linalg::ReduceOp>(
520+
loc, partialResult, init, partialReductionDims,
521+
[&linalgOp, &idx](OpBuilder &b, Location loc, ValueRange inputs) {
486522
// Get the combiner op.
487523
SmallVector<Operation *, 4> combinerOps;
488524
matchReduction(linalgOp.getRegionOutputArgs(), idx, combinerOps);
489525
Operation *clonedReductionOp = b.clone(*combinerOps[0]);
490526
// Combine the input at idx and output at numInits + idx.
491-
clonedReductionOp->setOperand(0, inputs[idx]);
492-
clonedReductionOp->setOperand(1, inputs[numInits + idx]);
493-
// Yield.
494-
yieldedValues.push_back(clonedReductionOp->getResult(0));
495-
}
496-
b.create<linalg::YieldOp>(loc, yieldedValues);
497-
});
498-
return MergeResult{
499-
{reduction.getOperation()},
500-
llvm::map_to_vector(reduction->getResults(),
501-
[](OpResult r) -> Value { return r; })};
527+
clonedReductionOp->setOperand(0, inputs[0]);
528+
clonedReductionOp->setOperand(1, inputs[1]);
529+
b.create<linalg::YieldOp>(loc, clonedReductionOp->getResult(0));
530+
});
531+
532+
mergeOperations.push_back(reduction);
533+
replacements.push_back(reduction->getResult(0));
534+
}
535+
536+
return MergeResult{mergeOperations, replacements};
537+
}
538+
539+
LogicalResult getPartialResultTilePosition(
540+
Operation *op, OpBuilder &b, unsigned resultNumber,
541+
ArrayRef<OpFoldResult> offsets, ArrayRef<OpFoldResult> sizes,
542+
SmallVector<OpFoldResult> &resultOffsets,
543+
SmallVector<OpFoldResult> &resultSizes,
544+
ArrayRef<int> reductionDims) const {
545+
auto linalgOp = cast<LinalgOp>(op);
546+
547+
AffineMap partialMap =
548+
getPartialResultAffineMap(linalgOp, reductionDims, resultNumber);
549+
for (AffineExpr dimExpr : partialMap.getResults()) {
550+
unsigned dim = cast<AffineDimExpr>(dimExpr).getPosition();
551+
resultSizes.push_back(sizes[dim]);
552+
553+
if (llvm::find(reductionDims, dim) != reductionDims.end()) {
554+
// Reduction dims are reduced, and are always output in the same
555+
// place. So use offset 0 for them.
556+
resultOffsets.push_back(b.getIndexAttr(0));
557+
} else {
558+
resultOffsets.push_back(offsets[dim]);
559+
}
560+
}
561+
562+
return success();
502563
}
503564
};
504565

mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
657657
resultOffset, resultSize);
658658
case scf::SCFTilingOptions::ReductionTilingStrategy::
659659
PartialReductionOuterReduction: {
660-
// TODO: This does not work for non identity accesses to the result tile.
661-
// The proper fix is to add a getPartialResultTilePosition method to
662-
// PartialReductionOpInterface.
663-
resultOffset =
664-
SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
665-
for (size_t i = 0; i < offsets.size(); i++) {
666-
resultSize.push_back(
667-
tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
660+
auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
661+
if (!redOp) {
662+
return rewriter.notifyMatchFailure(
663+
op, "PartialReductionOuterReduction tiling strategy is only supported"
664+
"for operations implementing PartialReductionOpInterface");
668665
}
669-
return success();
666+
// Get reduction dimensions.
667+
// TODO: PartialReductionOpInterface should really query TilingInterface
668+
// itself and find reduction dimensions.
669+
SmallVector<int> reductionDims;
670+
for (auto [idx, iteratorType] :
671+
llvm::enumerate(op.getLoopIteratorTypes())) {
672+
if (iteratorType == utils::IteratorType::reduction)
673+
reductionDims.push_back(idx);
674+
}
675+
return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
676+
resultOffset, resultSize,
677+
reductionDims);
678+
}
670679
default:
671680
return rewriter.notifyMatchFailure(op,
672681
"unhandled reduction tiling strategy");
673682
}
674-
}
675683
}
676684

677685
static FailureOr<MergeResult>

0 commit comments

Comments
 (0)