Skip to content

Commit 74f0660

Browse files
[mlir][Transform] NFC - Pass TransformState as an argument to applyToOne methods
This will allow implementing state-dependent behavior in the future. Differential Revision: https://reviews.llvm.org/D128327
1 parent 706e89d commit 74f0660

File tree

7 files changed

+41
-26
lines changed

7 files changed

+41
-26
lines changed

mlir/include/mlir/Dialect/Linalg/TransformOps/LinalgTransformOps.td

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ def DecomposeOp : Op<Transform_Dialect, "structured.decompose",
3333

3434
let extraClassDeclaration = [{
3535
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
36-
::mlir::linalg::LinalgOp target);
36+
::mlir::linalg::LinalgOp target, TransformState &state);
3737
}];
3838
}
3939

@@ -74,7 +74,7 @@ def GeneralizeOp : Op<Transform_Dialect, "structured.generalize",
7474

7575
let extraClassDeclaration = [{
7676
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
77-
::mlir::linalg::LinalgOp target);
77+
::mlir::linalg::LinalgOp target, TransformState &state);
7878
}];
7979
}
8080

@@ -96,7 +96,7 @@ def InterchangeOp : Op<Transform_Dialect, "structured.interchange",
9696

9797
let extraClassDeclaration = [{
9898
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
99-
::mlir::linalg::LinalgOp target);
99+
::mlir::linalg::LinalgOp target, TransformState &state);
100100
}];
101101
}
102102

@@ -124,7 +124,7 @@ def PadOp : Op<Transform_Dialect, "structured.pad",
124124

125125
let extraClassDeclaration = [{
126126
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
127-
::mlir::linalg::LinalgOp target);
127+
::mlir::linalg::LinalgOp target, TransformState &state);
128128
}];
129129
}
130130

@@ -149,7 +149,7 @@ def ScalarizeOp : Op<Transform_Dialect, "structured.scalarize",
149149

150150
let extraClassDeclaration = [{
151151
::mlir::FailureOr<::mlir::linalg::LinalgOp> applyToOne(
152-
::mlir::linalg::LinalgOp target);
152+
::mlir::linalg::LinalgOp target, TransformState &state);
153153
}];
154154
}
155155

@@ -218,7 +218,7 @@ def SplitReductionOp : Op<Transform_Dialect, "structured.split_reduction",
218218

219219
let extraClassDeclaration = [{
220220
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
221-
::mlir::linalg::LinalgOp target);
221+
::mlir::linalg::LinalgOp target, TransformState &state);
222222
}];
223223
}
224224

@@ -275,7 +275,8 @@ def VectorizeOp : Op<Transform_Dialect, "structured.vectorize",
275275
let assemblyFormat = "$target attr-dict";
276276

277277
let extraClassDeclaration = [{
278-
::mlir::FailureOr<Operation *> applyToOne(::mlir::Operation *target);
278+
::mlir::FailureOr<Operation *> applyToOne(
279+
::mlir::Operation *target, TransformState &state);
279280
}];
280281
}
281282

mlir/include/mlir/Dialect/SCF/TransformOps/SCFTransformOps.td

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,8 @@ def LoopPeelOp : Op<Transform_Dialect, "loop.peel",
8888
let assemblyFormat = "$target attr-dict";
8989

9090
let extraClassDeclaration = [{
91-
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
91+
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
92+
::mlir::scf::ForOp loop, TransformState &state);
9293
}];
9394
}
9495

@@ -115,7 +116,8 @@ def LoopPipelineOp : Op<Transform_Dialect, "loop.pipeline",
115116
let assemblyFormat = "$target attr-dict";
116117

117118
let extraClassDeclaration = [{
118-
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(::mlir::scf::ForOp loop);
119+
::mlir::FailureOr<::mlir::scf::ForOp> applyToOne(
120+
::mlir::scf::ForOp loop, TransformState &state);
119121
}];
120122
}
121123

@@ -137,7 +139,8 @@ def LoopUnrollOp : Op<Transform_Dialect, "loop.unroll",
137139
let assemblyFormat = "$target attr-dict";
138140

139141
let extraClassDeclaration = [{
140-
::mlir::LogicalResult applyToOne(::mlir::scf::ForOp loop);
142+
::mlir::LogicalResult applyToOne(
143+
::mlir::scf::ForOp loop, TransformState &state);
141144
}];
142145
}
143146

mlir/include/mlir/Dialect/Transform/IR/TransformInterfaces.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -582,9 +582,9 @@ class PossibleTopLevelTransformOpTrait
582582
/// transformation to a single operation handle and producing one or multiple
583583
/// operation handles.
584584
/// The op must implement a method with one of the following signatures:
585-
/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy)
586-
/// - FailureOr<SmallVector<convertible-to-Operation*>> applyToOne(OpTy)
587-
/// - LogicalResult applyToOne(OpTy)
585+
/// - FailureOr<convertible-to-Operation*> applyToOne(OpTy, state)
586+
/// - FailureOr<SmallVector<convertible-to-Operation*>>applyToOne(OpTy, state)
587+
/// - LogicalResult applyToOne(OpTy, state)
588588
/// to perform a transformation that is applied in turn to all payload IR
589589
/// operations that correspond to the handle of the transform IR operation.
590590
/// In the functions above, OpTy is either Operation * or a concrete payload IR
@@ -811,7 +811,7 @@ mlir::transform::TransformEachOpTrait<OpTy>::apply(
811811
// produced.
812812
DiagnosedSilenceableFailure result = detail::applyTransformToEach(
813813
targets, results, [&](TransformOpType specificOp) {
814-
return static_cast<OpTy *>(this)->applyToOne(specificOp);
814+
return static_cast<OpTy *>(this)->applyToOne(specificOp, state);
815815
});
816816
if (!result.succeeded())
817817
return result;

mlir/lib/Dialect/Linalg/TransformOps/LinalgTransformOps.cpp

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ static FailureOr<LinalgOp> tryApply(Operation *operation, Args &&...args) {
7676
// DecomposeOp
7777
//===----------------------------------------------------------------------===//
7878

79-
FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target) {
79+
FailureOr<LinalgOp> transform::DecomposeOp::applyToOne(LinalgOp target,
80+
TransformState &state) {
8081
FailureOr<LinalgOp> windowed =
8182
tryApply<DownscaleSizeOneWindowed2DConvolution>(target);
8283
if (succeeded(windowed))
@@ -220,7 +221,8 @@ LogicalResult transform::FuseOp::verify() {
220221
// GeneralizeOp
221222
//===----------------------------------------------------------------------===//
222223

223-
FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
224+
FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target,
225+
TransformState &state) {
224226
// Exit early if no transformation is needed.
225227
if (isa<GenericOp>(target))
226228
return target;
@@ -236,7 +238,8 @@ FailureOr<LinalgOp> transform::GeneralizeOp::applyToOne(LinalgOp target) {
236238
// InterchangeOp
237239
//===----------------------------------------------------------------------===//
238240

239-
FailureOr<LinalgOp> transform::InterchangeOp::applyToOne(LinalgOp target) {
241+
FailureOr<LinalgOp>
242+
transform::InterchangeOp::applyToOne(LinalgOp target, TransformState &state) {
240243
SmallVector<unsigned> interchangeVector =
241244
extractUIntArray(getIteratorInterchange());
242245
// Exit early if no transformation is needed.
@@ -272,7 +275,8 @@ LogicalResult transform::InterchangeOp::verify() {
272275
// PadOp
273276
//===---------------------------------------------------------------------===//
274277

275-
FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target) {
278+
FailureOr<LinalgOp> transform::PadOp::applyToOne(LinalgOp target,
279+
TransformState &state) {
276280
// Convert the integer packing flags to booleans.
277281
SmallVector<bool> packPaddings;
278282
for (int64_t packPadding : extractI64Array(getPackPaddings()))
@@ -377,7 +381,8 @@ LogicalResult transform::PadOp::verify() {
377381
// ScalarizeOp
378382
//===----------------------------------------------------------------------===//
379383

380-
FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
384+
FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target,
385+
TransformState &state) {
381386
LinalgTilingOptions tilingOptions;
382387
tilingOptions.scalarizeDynamicDims();
383388
// Tiling with "scalarize_dyn_dims" actually sets the same lambda as the tile
@@ -399,7 +404,8 @@ FailureOr<LinalgOp> transform::ScalarizeOp::applyToOne(LinalgOp target) {
399404
//===----------------------------------------------------------------------===//
400405

401406
FailureOr<SmallVector<Operation *>>
402-
transform::SplitReductionOp::applyToOne(LinalgOp target) {
407+
transform::SplitReductionOp::applyToOne(LinalgOp target,
408+
TransformState &state) {
403409
ControlSplitReductionFn splitFn = [&](LinalgOp) {
404410
return std::pair<int64_t, unsigned>(getSplitFactor(),
405411
getInsertSplitDimension());
@@ -455,7 +461,8 @@ void TileOp::print(OpAsmPrinter &p) {
455461
// VectorizeOp
456462
//===----------------------------------------------------------------------===//
457463

458-
FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target) {
464+
FailureOr<Operation *> VectorizeOp::applyToOne(Operation *target,
465+
TransformState &state) {
459466
if (!target->hasTrait<OpTrait::IsIsolatedFromAbove>()) {
460467
InFlightDiagnostic diag = emitOpError()
461468
<< "applies only to isolated-from-above targets";

mlir/lib/Dialect/SCF/TransformOps/SCFTransformOps.cpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,8 @@ transform::LoopOutlineOp::apply(transform::TransformResults &results,
127127
// LoopPeelOp
128128
//===----------------------------------------------------------------------===//
129129

130-
FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop) {
130+
FailureOr<scf::ForOp> transform::LoopPeelOp::applyToOne(scf::ForOp loop,
131+
TransformState &state) {
131132
scf::ForOp result;
132133
IRRewriter rewriter(loop->getContext());
133134
LogicalResult status =
@@ -180,7 +181,8 @@ loopScheduling(scf::ForOp forOp,
180181
}
181182
}
182183

183-
FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
184+
FailureOr<scf::ForOp>
185+
transform::LoopPipelineOp::applyToOne(scf::ForOp loop, TransformState &state) {
184186
scf::PipeliningOption options;
185187
options.getScheduleFn =
186188
[this](scf::ForOp forOp,
@@ -203,7 +205,8 @@ FailureOr<scf::ForOp> transform::LoopPipelineOp::applyToOne(scf::ForOp loop) {
203205
// LoopUnrollOp
204206
//===----------------------------------------------------------------------===//
205207

206-
LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop) {
208+
LogicalResult transform::LoopUnrollOp::applyToOne(scf::ForOp loop,
209+
TransformState &state) {
207210
if (failed(loopUnrollByFactor(loop, getFactor())))
208211
return reportUnknownTransformError(loop);
209212
return success();

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,8 @@ DiagnosedSilenceableFailure mlir::test::TestEmitRemarkAndEraseOperandOp::apply(
227227
}
228228

229229
FailureOr<SmallVector<Operation *>>
230-
mlir::test::TestWrongNumberOfResultsOp::applyToOne(Operation *) {
230+
mlir::test::TestWrongNumberOfResultsOp::applyToOne(
231+
Operation *, transform::TransformState &state) {
231232
return SmallVector<Operation *>{};
232233
}
233234

mlir/test/lib/Dialect/Transform/TestTransformDialectExtension.td

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ def TestWrongNumberOfResultsOp
140140
let cppNamespace = "::mlir::test";
141141
let extraClassDeclaration = [{
142142
::mlir::FailureOr<::llvm::SmallVector<::mlir::Operation *>> applyToOne(
143-
::mlir::Operation *target);
143+
::mlir::Operation *target, transform::TransformState &state);
144144
}];
145145
}
146146

0 commit comments

Comments
 (0)