Skip to content

Commit 17723e4

Browse files
authored
[mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_forall (#157932)
Following [PR #120118](#120118), this PR extends transform.structured.tile_reduction_using_forall so that it can be applied to any operation implementing `PartialReductionOpInterface`, rather than being restricted to LinalgOp. Existing tests relevant to linalg ops remain valid: https://github.com/llvm/llvm-project/blob/2a2296b1aab4614bf6c95c3003000832c9d43de5/mlir/test/Dialect/Linalg/transform-tile-reduction.mlir#L114 Additional tests for non-Linalg operations (e.g., IREE custom ops that implement `PartialReductionOpInterface`) will be added on the IREE side. Signed-off-by: Bangtian Liu <[email protected]>
1 parent c745c54 commit 17723e4

File tree

2 files changed

+14
-7
lines changed

2 files changed

+14
-7
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2017,8 +2017,8 @@ def TileReductionUsingForallOp :
20172017
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes,
20182018
OptionalAttr<DeviceMappingArrayAttr>:$mapping);
20192019
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
2020-
TransformHandleTypeInterface:$split_linalg_op,
2021-
TransformHandleTypeInterface:$combining_linalg_op,
2020+
TransformHandleTypeInterface:$split_op,
2021+
TransformHandleTypeInterface:$combining_op,
20222022
TransformHandleTypeInterface:$forall_op);
20232023

20242024
let builders = [
@@ -2042,7 +2042,7 @@ def TileReductionUsingForallOp :
20422042
let extraClassDeclaration = [{
20432043
::mlir::DiagnosedSilenceableFailure applyToOne(
20442044
::mlir::transform::TransformRewriter &rewriter,
2045-
::mlir::linalg::LinalgOp target,
2045+
Operation *target,
20462046
::mlir::transform::ApplyToEachResultList &results,
20472047
::mlir::transform::TransformState &state);
20482048
}];

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3033,10 +3033,17 @@ void transform::TileReductionUsingForallOp::build(
30333033
}
30343034

30353035
DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
3036-
transform::TransformRewriter &rewriter, LinalgOp target,
3036+
transform::TransformRewriter &rewriter, Operation *target,
30373037
transform::ApplyToEachResultList &results,
30383038
transform::TransformState &state) {
30393039
rewriter.setInsertionPoint(target);
3040+
3041+
auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
3042+
if (!partialReductionOp) {
3043+
return emitSilenceableFailure(
3044+
target->getLoc(),
3045+
"Operation should implement PartialReductionOpInterface");
3046+
}
30403047
SmallVector<OpFoldResult> numThreads =
30413048
getAsOpFoldResult(rewriter.getI64ArrayAttr(getNumThreads()));
30423049
SmallVector<OpFoldResult> tileSizes =
@@ -3058,14 +3065,14 @@ DiagnosedSilenceableFailure transform::TileReductionUsingForallOp::applyToOne(
30583065
extractFromIntegerArrayAttr<unsigned>(getReductionDims());
30593066
if (reductionDims.empty()) {
30603067
for (auto [idx, iteratorType] :
3061-
llvm::enumerate(target.getIteratorTypesArray())) {
3068+
llvm::enumerate(partialReductionOp.getLoopIteratorTypes())) {
30623069
if (iteratorType == utils::IteratorType::reduction)
30633070
reductionDims.push_back(idx);
30643071
}
30653072
}
30663073
options.setReductionDims(reductionDims);
3067-
FailureOr<scf::SCFTilingResult> result = scf::tileUsingSCF(
3068-
rewriter, cast<TilingInterface>(target.getOperation()), options);
3074+
FailureOr<scf::SCFTilingResult> result =
3075+
scf::tileUsingSCF(rewriter, partialReductionOp, options);
30693076

30703077
if (failed(result)) {
30713078
auto diag = emitSilenceableError() << "could not tile reduction";

0 commit comments

Comments
 (0)