Skip to content

Commit a257a06

Browse files
authored
[mlir][linalg-transform] dyn_cast DestinationStyleOpInterface and early return (#166299)
Use `dyn_cast` instead of `cast` and early return if op does not implement the `DestinationStyleOpInterface`. Before the change the following IR would cause a segfault when the transform interpreter is run, where `myop.a` and `myop.b` implement the `TilingInterface` and not the `DestinationStyleOpInterface`. Tried looking for ops in the upstream dialect that implement the `TilingInterface` and not the `DestinationStyleOpInterface` to add a test but could not find any. ```mlir module { func.func @fuse(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> { %mul = "myop.a"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32> %add = "myop.b"(%mul, %mul) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32> return %add : tensor<4x4x4xf32> } transform.sequence failures(propagate) { ^bb0(%func: !transform.any_op): %mul = transform.structured.match ops{["myop.a"]} in %func : (!transform.any_op) -> !transform.any_op %add = transform.structured.match ops{["myop.b"]} in %func : (!transform.any_op) -> !transform.any_op %loop, %tiled = transform.structured.tile_using_forall %add tile_sizes [1, 2, 4] : (!transform.any_op) -> (!transform.any_op, !transform.any_op) %mul_fused, %mul_containing = transform.structured.fuse_into_containing_op %mul into %tiled : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op) } } ```
1 parent 1a34007 commit a257a06

File tree

4 files changed

+81
-3
lines changed

4 files changed

+81
-3
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -997,8 +997,11 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
997997
// Iterate over the outputs of the producer and over the loop bbArgs and
998998
// check if any bbArg points to the same value as the producer output. In
999999
// such case, make the producer output point to the bbArg directly.
1000-
for (OpOperand &initOperandPtr :
1001-
cast<DestinationStyleOpInterface>(clone).getDpsInitsMutable()) {
1000+
auto dpsInterface = dyn_cast<DestinationStyleOpInterface>(clone);
1001+
if (!dpsInterface)
1002+
return;
1003+
1004+
for (OpOperand &initOperandPtr : dpsInterface.getDpsInitsMutable()) {
10021005
Value producerOperand =
10031006
clone->getOperand(initOperandPtr.getOperandNumber());
10041007
for (BlockArgument containerIterArg :
@@ -1060,7 +1063,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
10601063
resultNumber, offsets, sizes);
10611064

10621065
// Cleanup clone.
1063-
if (dyn_cast<LoopLikeOpInterface>(containingOp))
1066+
if (isa<LoopLikeOpInterface>(containingOp))
10641067
rewriter.eraseOp(tileableProducer);
10651068

10661069
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,40 @@ module {
253253

254254
// -----
255255

256+
#map = affine_map<(d0) -> (d0 * 2)>
257+
#map1 = affine_map<(d0) -> (d0 * 4)>
258+
module {
259+
// CHECK-LABEL: func.func @fuse_tileable_op_no_dps
260+
func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
261+
%0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
262+
%1 = tensor.empty() : tensor<4x4x4xf32>
263+
// CHECK: scf.forall
264+
%2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) {
265+
%3 = affine.apply #map(%arg3)
266+
%4 = affine.apply #map1(%arg4)
267+
// CHECK: "test.tiling_no_dps_op"
268+
// CHECK: "test.unregistered_op"
269+
%extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32>
270+
%5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
271+
scf.forall.in_parallel {
272+
tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32>
273+
}
274+
}
275+
return %2 : tensor<4x4x4xf32>
276+
}
277+
278+
module attributes {transform.with_named_sequence} {
279+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
280+
%op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
281+
%forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
282+
%fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
283+
transform.yield
284+
}
285+
}
286+
}
287+
288+
// -----
289+
256290
module {
257291
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
258292
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
10511051
return success();
10521052
}
10531053

1054+
//===----------------------------------------------------------------------===//
1055+
// TilingNoDpsOp
1056+
//===----------------------------------------------------------------------===//
1057+
1058+
SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
1059+
return {};
1060+
}
1061+
1062+
SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
1063+
return {};
1064+
}
1065+
1066+
FailureOr<TilingResult>
1067+
TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
1068+
ArrayRef<OpFoldResult> offsets,
1069+
ArrayRef<OpFoldResult> sizes) {
1070+
return failure();
1071+
}
1072+
1073+
LogicalResult TilingNoDpsOp::getResultTilePosition(
1074+
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
1075+
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
1076+
SmallVector<OpFoldResult> &resultSizes) {
1077+
return failure();
1078+
}
1079+
10541080
//===----------------------------------------------------------------------===//
10551081
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
10561082
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
3030
include "mlir/Interfaces/LoopLikeInterface.td"
3131
include "mlir/Interfaces/MemorySlotInterfaces.td"
3232
include "mlir/Interfaces/SideEffectInterfaces.td"
33+
include "mlir/Interfaces/TilingInterface.td"
3334
include "mlir/Interfaces/ValueBoundsOpInterface.td"
3435
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
3536
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -2887,6 +2888,20 @@ def TestLinalgFillOp :
28872888
}];
28882889
}
28892890

2891+
//===----------------------------------------------------------------------===//
2892+
// Test TilingInterface.
2893+
//===----------------------------------------------------------------------===//
2894+
2895+
def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
2896+
[Pure, DeclareOpInterfaceMethods<TilingInterface,
2897+
["getIterationDomain",
2898+
"getLoopIteratorTypes",
2899+
"getResultTilePosition",
2900+
"getTiledImplementation"]>]> {
2901+
let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
2902+
let results = (outs AnyRankedTensor:$result);
2903+
}
2904+
28902905
//===----------------------------------------------------------------------===//
28912906
// Test NVVM RequiresSM trait.
28922907
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)