From 564cc9bc1e6437539e0b8bb6c9fc38d7498b1736 Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Mon, 16 Dec 2024 17:31:15 +0000 Subject: [PATCH] [mlir][Linalg] Allow PartialReductionOpInterface ops in tile_reduction_using_for --- .../Dialect/Linalg/TransformOps/LinalgTransformOps.td | 6 +++--- .../Linalg/TransformOps/LinalgTransformOps.cpp | 11 +++++++++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td index 2e713bca24efc..081bf9b6d3b23 100644 --- a/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td +++ b/mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td @@ -1765,8 +1765,8 @@ def TileReductionUsingForOp : Op:$tile_sizes); let results = (outs Variadic:$fill_op, - TransformHandleTypeInterface:$split_linalg_op, - TransformHandleTypeInterface:$combining_linalg_op, + TransformHandleTypeInterface:$split_op, + TransformHandleTypeInterface:$combining_op, TransformHandleTypeInterface:$for_op); let builders = [ @@ -1784,7 +1784,7 @@ def TileReductionUsingForOp : Op(target); + if (!partialReductionOp) { + return emitSilenceableFailure( + target->getLoc(), + "Operation should implement PartialReductionOpInterface"); + } FailureOr result = scf::tileReductionUsingScf( - rewriter, cast(target.getOperation()), + rewriter, partialReductionOp, getAsOpFoldResult(rewriter.getI64ArrayAttr(getTileSizes()))); if (failed(result))