Skip to content

Commit bbca3c0

Browse files
committed
FuseOp: add use_forall argument that generates scf.forall loops
1 parent cdc8e8d commit bbca3c0

File tree

5 files changed

+57
-6
lines changed

5 files changed

+57
-6
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -410,13 +410,15 @@ def FuseOp : Op<Transform_Dialect, "structured.fuse",
410410
(ins TransformHandleTypeInterface:$target,
411411
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_sizes,
412412
DefaultValuedAttr<I64ArrayAttr, "{}">:$tile_interchange,
413-
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup);
413+
DefaultValuedAttr<BoolAttr, "false">:$apply_cleanup,
414+
DefaultValuedAttr<BoolAttr, "false">:$use_forall);
414415
let results = (outs TransformHandleTypeInterface:$transformed,
415416
Variadic<TransformHandleTypeInterface>:$loops);
416417

417418
let assemblyFormat = [{
418419
$target ($tile_sizes^)? (`interchange` $tile_interchange^)?
419-
(`apply_cleanup` `=` $apply_cleanup^)? attr-dict
420+
(`apply_cleanup` `=` $apply_cleanup^)?
421+
(`use_forall` `=` $use_forall^)? attr-dict
420422
`:` functional-type(operands, results)
421423
}];
422424
let hasVerifier = 1;

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,10 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
637637

638638
scf::SCFTilingOptions tilingOptions;
639639
tilingOptions.interchangeVector = tileInterchange;
640+
bool useForall = getUseForall();
641+
tilingOptions.setLoopType(useForall
642+
? scf::SCFTilingOptions::LoopType::ForallOp
643+
: scf::SCFTilingOptions::LoopType::ForOp);
640644
SmallVector<OpFoldResult> tileSizesOfr =
641645
getAsIndexOpFoldResult(rewriter.getContext(), tileSizes);
642646
tilingOptions = tilingOptions.setTileSizes(tileSizesOfr);
@@ -652,9 +656,11 @@ transform::FuseOp::apply(transform::TransformRewriter &rewriter,
652656
tileAndFuseOptions.cleanupPatterns = std::move(patterns);
653657
}
654658

659+
size_t numLoops =
660+
useForall ? 1 : tileSizes.size() - llvm::count(tileSizes, 0);
655661
LogicalResult result = applyTilingToAll(
656-
rewriter, getOperation(), state.getPayloadOps(getTarget()),
657-
tileSizes.size() - llvm::count(tileSizes, 0), transformResults,
662+
rewriter, getOperation(), state.getPayloadOps(getTarget()), numLoops,
663+
transformResults,
658664
[&](TilingInterface tilingInterfaceOp)
659665
-> FailureOr<scf::SCFTileAndFuseResult> {
660666
return tileConsumerAndFuseProducersUsingSCF(rewriter, tilingInterfaceOp,
@@ -676,7 +682,8 @@ LogicalResult transform::FuseOp::verify() {
676682

677683
SmallVector<int64_t> sizes =
678684
extractFromIntegerArrayAttr<int64_t>(getTileSizes());
679-
size_t numExpectedLoops = sizes.size() - llvm::count(sizes, 0);
685+
size_t numExpectedLoops =
686+
getUseForall() ? 1 : sizes.size() - llvm::count(sizes, 0);
680687
if (numExpectedLoops != getNumResults() - 1)
681688
return emitOpError() << "expects " << numExpectedLoops << " loop results";
682689

mlir/python/mlir/dialects/transform/structured.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ def __init__(
147147
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
148148
tile_interchange: OptionalIntList = None,
149149
apply_cleanup: Optional[bool] = False,
150+
use_forall: Optional[bool] = False,
150151
loc=None,
151152
ip=None,
152153
):
@@ -160,6 +161,7 @@ def __init__(
160161
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
161162
tile_interchange: OptionalIntList = None,
162163
apply_cleanup: Optional[bool] = False,
164+
use_forall: Optional[bool] = False,
163165
loc=None,
164166
ip=None,
165167
):
@@ -173,14 +175,15 @@ def __init__(
173175
tile_sizes: Optional[Union[DynamicIndexList, ArrayAttr]] = None,
174176
tile_interchange: OptionalIntList = None,
175177
apply_cleanup: Optional[bool] = False,
178+
use_forall: Optional[bool] = False,
176179
loc=None,
177180
ip=None,
178181
):
179182
tile_sizes = tile_sizes if tile_sizes else []
180183
tile_interchange = tile_interchange if tile_interchange else []
181184
_, tile_sizes, _ = _dispatch_dynamic_index_list(tile_sizes)
182185
_, tile_interchange, _ = _dispatch_dynamic_index_list(tile_interchange)
183-
num_loops = sum(0 if v == 0 else 1 for v in tile_sizes)
186+
num_loops = 1 if use_forall else sum(0 if v == 0 else 1 for v in tile_sizes)
184187

185188
if isinstance(loop_types_or_target, (Operation, Value, OpView)):
186189
loop_types = [transform.AnyOpType.get()] * num_loops
@@ -200,6 +203,7 @@ def __init__(
200203
tile_sizes=tile_sizes,
201204
tile_interchange=tile_interchange,
202205
apply_cleanup=apply_cleanup,
206+
use_forall=use_forall,
203207
loc=loc,
204208
ip=ip,
205209
)

mlir/test/Dialect/Linalg/transform-op-fuse.mlir

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,31 @@ module attributes {transform.with_named_sequence} {
5757

5858
// -----
5959

60+
// CHECK-LABEL: func.func @fuse_unary_forall
61+
func.func @fuse_unary_forall(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
62+
63+
// CHECK: %[[RES:.*]] = scf.forall
64+
// CHECK: linalg.exp
65+
// CHECK: linalg.add
66+
// CHECK: return %[[RES]]
67+
%0 = linalg.exp ins(%arg0 : tensor<?x?xf32>)
68+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
69+
%1 = linalg.add ins(%0, %arg0 : tensor<?x?xf32>, tensor<?x?xf32>)
70+
outs(%arg1: tensor<?x?xf32>) -> tensor<?x?xf32>
71+
return %1 : tensor<?x?xf32>
72+
}
73+
74+
module attributes {transform.with_named_sequence} {
75+
transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) {
76+
%0 = transform.structured.match ops{["linalg.add"]} in %arg1 : (!transform.any_op) -> !transform.any_op
77+
%1, %loop = transform.structured.fuse %0 [32, 32] use_forall = true
78+
: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
79+
transform.yield
80+
}
81+
}
82+
83+
// -----
84+
6085
// CHECK-LABEL: func.func @interchange_reduction
6186
// CHECK-SAME: (%[[INPUT:.+]]: tensor<12x7x25xf32>)
6287
func.func @interchange_reduction(%input: tensor<12x7x25xf32>) -> tensor<12x25xf32> {

mlir/test/python/dialects/transform_structured_ext.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,19 @@ def testFuseOpCompact(target):
114114
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
115115

116116

117+
@run
118+
@create_sequence
119+
def testFuseOpCompactForall(target):
120+
structured.FuseOp(
121+
target, tile_sizes=[4, 8], apply_cleanup=True, use_forall=True,
122+
)
123+
# CHECK-LABEL: TEST: testFuseOpCompact
124+
# CHECK: transform.sequence
125+
# CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
126+
# CHECK-SAME: apply_cleanup = true use_forall = true
127+
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
128+
129+
117130
@run
118131
@create_sequence
119132
def testFuseOpNoArg(target):

0 commit comments

Comments
 (0)