Skip to content

Commit 93b8734

Browse files
committed
Add test and address comments
1 parent 2234d6c commit 93b8734

File tree

4 files changed

+129
-1
lines changed

4 files changed

+129
-1
lines changed

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1063,7 +1063,7 @@ tileAndFuseFirstExtractUse(RewriterBase &rewriter, Diagnostic &diag,
10631063
resultNumber, offsets, sizes);
10641064

10651065
// Cleanup clone.
1066-
if (dyn_cast<LoopLikeOpInterface>(containingOp))
1066+
if (isa<LoopLikeOpInterface>(containingOp))
10671067
rewriter.eraseOp(tileableProducer);
10681068

10691069
return std::make_tuple(tileAndFuseResult->tiledOps, newContainingOp);

mlir/test/Dialect/Linalg/transform-op-fuse-into-containing.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,40 @@ module {
253253

254254
// -----
255255

256+
#map = affine_map<(d0) -> (d0 * 2)>
#map1 = affine_map<(d0) -> (d0 * 4)>
module {
  // Fuse a producer that implements TilingInterface but is NOT destination
  // passing style into a containing scf.forall.
  // CHECK-LABEL: func.func @fuse_tileable_op_no_dps
  func.func @fuse_tileable_op_no_dps(%arg0: tensor<4x4x4xf32>, %arg1: tensor<4x4x4xf32>) -> tensor<4x4x4xf32> {
    %0 = "test.tiling_no_dps_op"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4x4x4xf32>) -> tensor<4x4x4xf32>
    %1 = tensor.empty() : tensor<4x4x4xf32>
    // CHECK: scf.forall
    %2 = scf.forall (%arg2, %arg3, %arg4) in (4, 2, 1) shared_outs(%arg5 = %1) -> (tensor<4x4x4xf32>) {
      %3 = affine.apply #map(%arg3)
      %4 = affine.apply #map1(%arg4)
      // The producer is re-created tiled inside the loop, before its consumer.
      // CHECK: "test.tiling_no_dps_op"
      // CHECK: "test.unregistered_op"
      %extracted_slice = tensor.extract_slice %0[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<4x4x4xf32> to tensor<1x2x4xf32>
      %5 = "test.unregistered_op"(%extracted_slice, %extracted_slice) : (tensor<1x2x4xf32>, tensor<1x2x4xf32>) -> tensor<1x2x4xf32>
      scf.forall.in_parallel {
        tensor.parallel_insert_slice %5 into %arg5[%arg2, %3, %4] [1, 2, 4] [1, 1, 1] : tensor<1x2x4xf32> into tensor<4x4x4xf32>
      }
    }
    return %2 : tensor<4x4x4xf32>
  }

  module attributes {transform.with_named_sequence} {
    transform.named_sequence @__transform_main(%arg0: !transform.any_op {transform.readonly}) {
      %op = transform.structured.match ops{["test.tiling_no_dps_op"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %forall = transform.structured.match ops{["scf.forall"]} in %arg0 : (!transform.any_op) -> !transform.any_op
      %fused, %new_containing = transform.structured.fuse_into_containing_op %op into %forall : (!transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op)
      transform.yield
    }
  }
}
287+
288+
// -----
289+
256290
module {
257291
// CHECK-LABEL: func.func @fuse_tileable_op_through_bbarg_inout_nested
258292
// CHECK-SAME: %[[ARG0:[0-9a-z]+]]: tensor<?x?x?xf32>

mlir/test/lib/Dialect/Test/TestOpDefs.cpp

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1051,6 +1051,84 @@ LogicalResult OpWithRefineTypeInterfaceOp::refineReturnTypes(
10511051
return success();
10521052
}
10531053

1054+
//===----------------------------------------------------------------------===//
1055+
// TilingNoDpsOp
1056+
//===----------------------------------------------------------------------===//
1057+
1058+
/// Returns a slice of `source` described by `offsets`/`sizes`/`strides`.
/// When the requested slice provably covers the whole tensor (all-zero
/// offsets, unit strides, and sizes equal to the static source shape), the
/// source value is forwarded unchanged; otherwise a tensor.extract_slice op
/// is materialized.
static Value getSlice(OpBuilder &builder, Location loc, Value source,
                      ArrayRef<OpFoldResult> offsets,
                      ArrayRef<OpFoldResult> sizes,
                      ArrayRef<OpFoldResult> strides) {
  auto staticOffsets = getConstantIntValues(offsets);
  auto staticSizes = getConstantIntValues(sizes);
  auto staticStrides = getConstantIntValues(strides);

  // Identity-slice fast path. Checking the sizes alone is not enough: a
  // nonzero offset or non-unit stride selects different elements even when
  // the size vector matches the source shape.
  auto sourceShape = cast<ShapedType>(source.getType()).getShape();
  if (staticOffsets && staticSizes && staticStrides &&
      llvm::all_of(*staticOffsets, [](int64_t ofs) { return ofs == 0; }) &&
      llvm::all_of(*staticStrides, [](int64_t step) { return step == 1; }) &&
      ArrayRef(*staticSizes) == sourceShape)
    return source;

  return {mlir::tensor::ExtractSliceOp::create(builder, loc, source, offsets,
                                               sizes, strides)};
}
1073+
1074+
/// Clones `type` with the static extents in `sizes`. Returns a null type
/// when any size is not a compile-time constant.
static ShapedType getSliceType(ShapedType type, ArrayRef<OpFoldResult> sizes) {
  std::optional<SmallVector<int64_t>> constSizes = getConstantIntValues(sizes);
  if (!constSizes)
    return nullptr;
  return type.cloneWith(*constSizes, type.getElementType());
}
1080+
1081+
SmallVector<Range> TilingNoDpsOp::getIterationDomain(OpBuilder &builder) {
  // One range per result dimension: [0, dimSize) with step 1.
  MLIRContext *ctx = getContext();
  OpFoldResult zero = getAsIndexOpFoldResult(ctx, 0);
  OpFoldResult one = getAsIndexOpFoldResult(ctx, 1);
  SmallVector<Range> domain;
  for (int64_t dimSize : cast<ShapedType>(getResult().getType()).getShape())
    domain.push_back(Range{zero, getAsIndexOpFoldResult(ctx, dimSize), one});
  return domain;
}
1091+
1092+
SmallVector<utils::IteratorType> TilingNoDpsOp::getLoopIteratorTypes() {
  // Every loop dimension of this op can be tiled in parallel.
  int64_t rank = cast<ShapedType>(getResult().getType()).getRank();
  return SmallVector<utils::IteratorType>(static_cast<size_t>(rank),
                                          utils::IteratorType::parallel);
}
1098+
1099+
/// Produces the tile of this op selected by `offsets`/`sizes`: slices every
/// operand with unit strides (the op is elementwise-like, so operands and
/// result share the iteration space) and creates a smaller
/// test.tiling_no_dps_op on the slices.
FailureOr<TilingResult>
TilingNoDpsOp::getTiledImplementation(OpBuilder &builder,
                                      ArrayRef<OpFoldResult> offsets,
                                      ArrayRef<OpFoldResult> sizes) {
  Location loc = getLoc();
  auto strides = SmallVector<OpFoldResult>(
      static_cast<size_t>(cast<ShapedType>(getOperand(0).getType()).getRank()),
      getAsIndexOpFoldResult(getContext(), 1));
  auto inputSlices = llvm::map_to_vector(getOperands(), [&](Value operand) {
    return getSlice(builder, loc, operand, offsets, sizes, strides);
  });
  // Bail out instead of creating an op with a null result type when the tile
  // sizes are not all static (getSliceType returns null in that case).
  ShapedType resultType =
      getSliceType(cast<ShapedType>(getResult().getType()), sizes);
  if (!resultType)
    return failure();
  auto tiledOp = TilingNoDpsOp::create(builder, loc, TypeRange{resultType},
                                       ValueRange(inputSlices));
  // Report only slices that were actually materialized: getSlice may forward
  // an operand unchanged (no new op) when the tile covers the whole tensor,
  // in which case getDefiningOp() would be the producer or null.
  SmallVector<Operation *> slices;
  for (Value val : inputSlices)
    if (auto sliceOp = val.getDefiningOp<mlir::tensor::ExtractSliceOp>())
      slices.push_back(sliceOp);
  return TilingResult{.tiledOps = {tiledOp},
                      .tiledValues = SmallVector<Value>{tiledOp.getResult()},
                      .generatedSlices = slices};
}
1121+
1122+
LogicalResult TilingNoDpsOp::getResultTilePosition(
    OpBuilder &builder, unsigned resultNumber, ArrayRef<OpFoldResult> offsets,
    ArrayRef<OpFoldResult> sizes, SmallVector<OpFoldResult> &resultOffsets,
    SmallVector<OpFoldResult> &resultSizes) {
  // Identity mapping: the result tile sits at exactly the iteration-space
  // offsets/sizes it was produced for.
  resultOffsets = llvm::to_vector(offsets);
  resultSizes = llvm::to_vector(sizes);
  return success();
}
1131+
10541132
//===----------------------------------------------------------------------===//
10551133
// OpWithShapedTypeInferTypeAdaptorInterfaceOp
10561134
//===----------------------------------------------------------------------===//

mlir/test/lib/Dialect/Test/TestOps.td

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
12
//===-- TestOps.td - Test dialect operation definitions ----*- tablegen -*-===//
23
//
34
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
@@ -30,6 +31,7 @@ include "mlir/Interfaces/InferTypeOpInterface.td"
3031
include "mlir/Interfaces/LoopLikeInterface.td"
3132
include "mlir/Interfaces/MemorySlotInterfaces.td"
3233
include "mlir/Interfaces/SideEffectInterfaces.td"
34+
include "mlir/Interfaces/TilingInterface.td"
3335
include "mlir/Interfaces/ValueBoundsOpInterface.td"
3436
include "mlir/Dialect/Bufferization/IR/BufferizableOpInterface.td"
3537
include "mlir/Dialect/Bufferization/IR/BufferizationTypeInterfaces.td"
@@ -2887,6 +2889,20 @@ def TestLinalgFillOp :
28872889
}];
28882890
}
28892891

2892+
//===----------------------------------------------------------------------===//
// Test TilingInterface.
//===----------------------------------------------------------------------===//

def Test_TilingNoDpsOp : TEST_Op<"tiling_no_dps_op",
    [Pure, DeclareOpInterfaceMethods<TilingInterface,
        ["getIterationDomain",
         "getLoopIteratorTypes",
         "getResultTilePosition",
         "getTiledImplementation"]>]> {
  let summary = "test op implementing TilingInterface without DPS";
  let description = [{
    Elementwise-like test op that implements TilingInterface but not
    DestinationStyleOpInterface, used to exercise tiling and fusion
    transforms on non-destination-passing-style producers.
  }];
  let arguments = (ins AnyRankedTensor:$lhs, AnyRankedTensor:$rhs);
  let results = (outs AnyRankedTensor:$result);
}
2905+
28902906
//===----------------------------------------------------------------------===//
28912907
// Test NVVM RequiresSM trait.
28922908
//===----------------------------------------------------------------------===//

0 commit comments

Comments
 (0)