Skip to content

Commit 33db61c

Browse files
committed
[mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for
1 parent d098ce0 commit 33db61c

File tree

2 files changed

+12
-5
lines changed

2 files changed

+12
-5
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1765,8 +1765,8 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
17651765
let arguments = (ins TransformHandleTypeInterface:$target,
17661766
DefaultValuedAttr<DenseI64ArrayAttr, "{}">:$tile_sizes);
17671767
let results = (outs Variadic<TransformHandleTypeInterface>:$fill_op,
1768-
TransformHandleTypeInterface:$split_linalg_op,
1769-
TransformHandleTypeInterface:$combining_linalg_op,
1768+
TransformHandleTypeInterface:$split_op,
1769+
TransformHandleTypeInterface:$combining_op,
17701770
TransformHandleTypeInterface:$for_op);
17711771

17721772
let builders = [
@@ -1784,7 +1784,7 @@ def TileReductionUsingForOp : Op<Transform_Dialect, "structured.tile_reduction_u
17841784
let extraClassDeclaration = [{
17851785
::mlir::DiagnosedSilenceableFailure applyToOne(
17861786
::mlir::transform::TransformRewriter &rewriter,
1787-
::mlir::linalg::LinalgOp target,
1787+
Operation *target,
17881788
::mlir::transform::ApplyToEachResultList &results,
17891789
::mlir::transform::TransformState &state);
17901790
}];

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2627,12 +2627,19 @@ void transform::TileReductionUsingForOp::build(
26272627
}
26282628

26292629
DiagnosedSilenceableFailure transform::TileReductionUsingForOp::applyToOne(
2630-
transform::TransformRewriter &rewriter, LinalgOp target,
2630+
transform::TransformRewriter &rewriter, Operation *target,
26312631
transform::ApplyToEachResultList &results,
26322632
transform::TransformState &state) {
26332633
rewriter.setInsertionPoint(target);
2634+
2635+
auto partialReductionOp = dyn_cast<PartialReductionOpInterface>(target);
2636+
if (!partialReductionOp) {
2637+
return emitSilenceableFailure(
2638+
target->getLoc(),
2639+
"Operation should implement PartialReductionOpInterface");
2640+
}
26342641
FailureOr<scf::SCFReductionTilingResult> result = scf::tileReductionUsingScf(
2635-
rewriter, cast<PartialReductionOpInterface>(target.getOperation()),
2642+
rewriter, partialReductionOp,
26362643
getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes())));
26372644

26382645
if (failed(result))

0 commit comments

Comments
 (0)