diff --git a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp index 3a433825fd31a..aa8206347e9b1 100644 --- a/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp +++ b/mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp @@ -997,8 +997,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag, // Iterate over the outputs of the producer and over the loop bbArgs and // check if any bbArg points to the same value as the producer output. In // such case, make the producer output point to the bbArg directly. - for (OpOperand &initOperandPtr : - cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) { + auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone); + if (!dpsInterface) + return; + + for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) { Value producerOperand = clone->getOperand(initOperandPtr.getOperandNumber()); for (BlockArgument containerIterArg : 
- if (dyn_cast<LoopLikeOpInterface>(containingOp)) + if (isa<LoopLikeOpInterface>(containingOp)) rewriter.eraseOp(tileableProducer); return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp); diff --git a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir index e5216089692b4..ab38f9f2f5943 100644 --- a/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir +++ b/mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir @@ -253,6 +253,40 @@ module { // ----- +#map = affine_map<(d0) -> (d0 * 2)> +#map1 = affine_map<(d0) -> (d0 * 4)> +module { + // CHECK-LABEL: func.func @fuse_tileable_op_no_dps + func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> { + %0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32> + %1 = tensor.empty() : tensor<4x4x4xf32> + // CHECK: scf.forall + %2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) { + %3 = affine.apply #map(%arg3) + %4 = affine.apply #map1(%arg4) + // CHECK: "test.tiling_no_dps_op" + // CHECK: "test.unregistered_op" + %extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32> + %5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32> + scf.forall.in_parallel { + tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32> + } + } + return %2 : tensor<4x4x4xf32> + } + + module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) { + %op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op + %forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> 
!transform.any_op + %fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) + transform.yield + } + } +} + +// ----- + module { // CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested // CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 4d4ec02546bc7..e21cf94f84b66 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -1051,6 +1051,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes( return success(); } +//===----------------------------------------------------------------------===// +// TilingNoDpsOp +//===----------------------------------------------------------------------===// + +SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) { + return {}; +} + +SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() { + return {}; +} + +FailureOr<TilingResult> +TilingNoDpsOp::getTiledImplementation(OpBuilder &builder, + ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes) { + return failure(); +} + +LogicalResult TilingNoDpsOp::getResultTilePosition( + OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets, + ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets, + SmallVector<OpFoldResult> &resultSizes) { + return failure(); +} + //===----------------------------------------------------------------------===// // OpWithShapedTypeInferTypeAdaptorInterfaceOp //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index a3430ba49a291..620d950c0d2af 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/MemorySlotInterfaces.td" include 
"mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/Interfaces/TilingInterface.td" include "mlir/Interfaces/ValueBoundsOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td" include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td" @@ -2887,6 +2888,20 @@ def TestLinalgFillOp : }]; } +//===----------------------------------------------------------------------===// +// Test TilingInterface. +//===----------------------------------------------------------------------===// + +def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op", + [Pure, DeclareOpInterfaceMethods<TilingInterface>]> { + let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs); + let results = (outs AnyRankedTensor:$result); +} + //===----------------------------------------------------------------------===// // Test NVVM RequiresSM trait. //===----------------------------------------------------------------------===//