Skip to content

Commit 72dd81d

Browse files
committed
Add test and address comments
1 parent 2234d6c commit 72dd81d

File tree

4 files changed

+76
-1
lines changed

4 files changed

+76
-1
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
10631063
resultNumber, offsets, sizes);
10641064

10651065
// Cleanup clone.
1066-
if (dyn_cast<LoopLikeOpInterface>(containingOp))
1066+
if (isa<LoopLikeOpInterface>(containingOp))
10671067
rewriter.eraseOp(tileableProducer);
10681068

10691069
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,40 @@ module {
253253

254254
// -----
255255

256+
#map = affine_map<(d0) -> (d0 * 2)>
257+
#map1 = affine_map<(d0) -> (d0 * 4)>
258+
module {
259+
// CHECK-LABEL: func.func @fuse_tileable_op_no_dps
260+
func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
261+
%0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
262+
%1 = tensor.empty() : tensor<4x4x4xf32>
263+
// CHECK: scf.forall
264+
%2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) {
265+
%3 = affine.apply #map(%arg3)
266+
%4 = affine.apply #map1(%arg4)
267+
// CHECK: "test.tiling_no_dps_op"
268+
// CHECK: "test.unregistered_op"
269+
%extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32>
270+
%5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
271+
scf.forall.in_parallel {
272+
tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32>
273+
}
274+
}
275+
return %2 : tensor<4x4x4xf32>
276+
}
277+
278+
module attributes {transform.with_named_sequence} {
279+
transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
280+
%op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
281+
%forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
282+
%fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
283+
transform.yield
284+
}
285+
}
286+
}
287+
288+
// -----
289+
256290
module {
257291
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
258292
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,32 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
10511051
return success();
10521052
}
10531053

1054+
//===----------------------------------------------------------------------===//
1055+
// TilingNoDpsOp
1056+
//===----------------------------------------------------------------------===//
1057+
1058+
SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
1059+
return {};
1060+
}
1061+
1062+
SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
1063+
return {};
1064+
}
1065+
1066+
FailureOr<TilingResult>
1067+
TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
1068+
ArrayRef<OpFoldResult> offsets,
1069+
ArrayRef<OpFoldResult> sizes) {
1070+
return failure();
1071+
}
1072+
1073+
LogicalResult TilingNoDpsOp::getResultTilePosition(
1074+
OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
1075+
ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
1076+
SmallVector<OpFoldResult> &resultSizes) {
1077+
return failure();
1078+
}
1079+
10541080
//===----------------------------------------------------------------------===//
10551081
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
10561082
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
3030
include "mlir/Interfaces/LoopLikeInterface.td"
3131
include "mlir/Interfaces/MemorySlotInterfaces.td"
3232
include "mlir/Interfaces/SideEffectInterfaces.td"
33+
include "mlir/Interfaces/TilingInterface.td"
3334
include "mlir/Interfaces/ValueBoundsOpInterface.td"
3435
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
3536
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -2887,6 +2888,20 @@ def TestLinalgFillOp :
28872888
}];
28882889
}
28892890

2891+
//===----------------------------------------------------------------------===//
2892+
// Test TilingInterface.
2893+
//===----------------------------------------------------------------------===//
2894+
2895+
def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
2896+
[Pure, DeclareOpInterfaceMethods<TilingInterface,
2897+
["getIterationDomain",
2898+
"getLoopIteratorTypes",
2899+
"getResultTilePosition",
2900+
"getTiledImplementation"]>]> {
2901+
let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
2902+
let results = (outs AnyRankedTensor:$result);
2903+
}
2904+
28902905
//===----------------------------------------------------------------------===//
28912906
// Test NVVM RequiresSM trait.
28922907
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)