Skip to content

Commit ddec497

Browse files
add HoistLoopInvariantCodeMotionOp
1 parent 24c5926 commit ddec497

File tree

3 files changed

+109
-0
lines changed

3 files changed

+109
-0
lines changed

mlir/include/mlir/Dialect/Transform/LoopExtension/LoopExtensionOps.td

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,4 +73,60 @@ def HoistLoopInvariantSubsetsOp
7373
}];
7474
}
7575

76+
def HoistLoopInvariantCodeMotionOp
77+
: TransformDialectOp<"loop.hoist_loop_invariant_code_motion",
78+
[TransformOpInterface, TransformEachOpTrait,
79+
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
80+
ReportTrackingListenerFailuresOpTrait]> {
81+
let summary = "Hoist loop invariant instructions outside of the loop";
82+
let description = [{
83+
This transform hoists loop-invariant instructions out of the targeted
84+
loop-like op.
85+
86+
Example:
87+
```
88+
%m = memref.alloc() : memref<10xf32>
89+
%cf7 = arith.constant 7.0 : f32
90+
%cf8 = arith.constant 8.0 : f32
91+
92+
affine.for %arg0 = 0 to 10 {
93+
affine.for %arg1 = 0 to 10 {
94+
%v0 = arith.addf %cf7, %cf8 : f32
95+
}
96+
}
97+
```
98+
Is transformed to:
99+
```
100+
%alloc = memref.alloc() : memref<10xf32>
101+
%cst = arith.constant 7.000000e+00 : f32
102+
%cst_0 = arith.constant 8.000000e+00 : f32
103+
%0 = arith.addf %cst, %cst_0 : f32
104+
affine.for %arg0 = 0 to 10 {
105+
}
106+
affine.for %arg0 = 0 to 10 {
107+
}
108+
```
109+
110+
loop-invariant instructions are hoisted only if they are pure ops and
111+
they are ancestors of the parent regionOp of all operands.
112+
113+
This transform reads the target handle and modifies the payload. This
114+
transform does not invalidate any handles, but loop-like ops are replaced
115+
with new loop-like ops when a loop-invariant op is hoisted. The transform
116+
rewriter updates all handles accordingly.
117+
}];
118+
119+
let arguments = (ins TransformHandleTypeInterface:$target);
120+
let results = (outs);
121+
let assemblyFormat = "$target attr-dict `:` type($target)";
122+
123+
let extraClassDeclaration = [{
124+
::mlir::DiagnosedSilenceableFailure applyToOne(
125+
::mlir::transform::TransformRewriter &rewriter,
126+
::mlir::LoopLikeOpInterface loopLikeOp,
127+
::mlir::transform::ApplyToEachResultList &results,
128+
::mlir::transform::TransformState &state);
129+
}];
130+
}
131+
76132
#endif // MLIR_DIALECT_TRANSFORM_LOOPEXTENSION_LOOPEXTENSIONOPS

mlir/lib/Dialect/Transform/LoopExtension/LoopExtensionOps.cpp

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,22 @@ void transform::HoistLoopInvariantSubsetsOp::getEffects(
3232
transform::onlyReadsHandle(getTargetMutable(), effects);
3333
transform::modifiesPayload(effects);
3434
}
35+
36+
//===----------------------------------------------------------------------===//
37+
// HoistLoopInvariantCodeMotionOp
38+
//===----------------------------------------------------------------------===//
39+
40+
DiagnosedSilenceableFailure
41+
transform::HoistLoopInvariantCodeMotionOp::applyToOne(
42+
transform::TransformRewriter &rewriter, LoopLikeOpInterface loopLikeOp,
43+
transform::ApplyToEachResultList &results,
44+
transform::TransformState &state) {
45+
(void)moveLoopInvariantCode(loopLikeOp);
46+
return DiagnosedSilenceableFailure::success();
47+
}
48+
49+
void transform::HoistLoopInvariantCodeMotionOp::getEffects(
50+
SmallVectorImpl<MemoryEffects::EffectInstance> &effects) {
51+
transform::onlyReadsHandle(getTargetMutable(), effects);
52+
transform::modifiesPayload(effects);
53+
}

mlir/test/Dialect/Transform/test-loop-transforms.mlir

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,3 +79,37 @@ module attributes {transform.with_named_sequence} {
7979
transform.yield
8080
}
8181
}
82+
83+
// -----
84+
85+
func.func @nested_loops_both_having_invariant_code_transform() {
86+
%m = memref.alloc() : memref<10xf32>
87+
%cf7 = arith.constant 7.0 : f32
88+
%cf8 = arith.constant 8.0 : f32
89+
affine.for %arg0 = 0 to 10 {
90+
%v0 = arith.addf %cf7, %cf8 : f32
91+
affine.for %arg1 = 0 to 10 {
92+
%v1 = arith.addf %v0, %cf8 : f32
93+
affine.store %v0, %m[%arg0] : memref<10xf32>
94+
}
95+
}
96+
97+
// CHECK: memref.alloc() : memref<10xf32>
98+
// CHECK-NEXT: %[[CST0:.*]] = arith.constant 7.000000e+00 : f32
99+
// CHECK-NEXT: %[[CST1:.*]] = arith.constant 8.000000e+00 : f32
100+
// CHECK-NEXT: %[[ADD0:.*]] = arith.addf %[[CST0]], %[[CST1]] : f32
101+
// CHECK-NEXT: arith.addf %[[ADD0]], %[[CST1]] : f32
102+
// CHECK-NEXT: affine.for
103+
// CHECK-NEXT: affine.for
104+
// CHECK-NEXT: affine.store
105+
return
106+
}
107+
108+
module attributes {transform.with_named_sequence} {
109+
transform.named_sequence @__transform_main(
110+
%arg0: !transform.any_op {transform.readonly}) {
111+
%loop = transform.structured.match interface{LoopLikeInterface} in %arg0 : (!transform.any_op) -> !transform.any_op
112+
transform.loop.hoist_loop_invariant_code_motion %loop : !transform.any_op
113+
transform.yield
114+
}
115+
}

0 commit comments

Comments
 (0)