Skip to content

Commit 0ac44d7

Browse files
authored
Bump llvm - transform.apply_registered_pass options params (#1055)
* llvm/llvm-project#139340 ``` sed -i "s/isZeroIndex/isZeroInteger/g" **/*.h sed -i "s/isZeroIndex/isZeroInteger/g" **/*.cpp ``` * llvm/llvm-project#141466 & llvm/llvm-project#141019 * Add `BufferizationState &state` to `bufferize` and `getBuffer` * llvm/llvm-project#143159 & llvm/llvm-project#142683 & llvm/llvm-project#143779 * Updates to `transform.apply_registered_pass` and its Python-bindings * llvm/llvm-project#143217 * `tilingResult->mergeResult.replacements` -> `tilingResult->replacements` * llvm/llvm-project#140559 & llvm/llvm-project#143871 * Change CHECK lines which expected `amx.` ops to `llvm.call_intrinsic`s & fix which enables conversion again.
1 parent 75b113b commit 0ac44d7

File tree

15 files changed

+41
-94
lines changed

15 files changed

+41
-94
lines changed

build_tools/llvm_version.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
faf5d747f174cc9d714839f0d3bce1a783eac2ac
1+
d698ede748e66f5519cb8481abc2df89a994a059

lib/TPP/Dialect/Check/BufferizableOpInterfaceImpl.cpp

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,12 @@ struct ExpectTrueLayoutInterface
4848
}
4949

5050
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
51-
const BufferizationOptions &options) const {
51+
const BufferizationOptions &options,
52+
BufferizationState &state) const {
5253
check::ExpectTrueOp expectTrueOp = cast<check::ExpectTrueOp>(op);
5354

5455
FailureOr<Value> maybeSrcBuffer =
55-
getBuffer(rewriter, expectTrueOp.getOperand(), options);
56+
getBuffer(rewriter, expectTrueOp.getOperand(), options, state);
5657
if (failed(maybeSrcBuffer))
5758
return failure();
5859
Value srcBuffer = *maybeSrcBuffer;
@@ -91,16 +92,17 @@ struct ExpectAlmostEqLayoutInterface
9192
}
9293

9394
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
94-
const BufferizationOptions &options) const {
95+
const BufferizationOptions &options,
96+
BufferizationState &state) const {
9597
check::ExpectAlmostEqOp almostEqOp = cast<check::ExpectAlmostEqOp>(op);
9698
FailureOr<Value> maybeFirstBuffer =
97-
getBuffer(rewriter, almostEqOp.getLhs(), options);
99+
getBuffer(rewriter, almostEqOp.getLhs(), options, state);
98100
if (failed(maybeFirstBuffer))
99101
return failure();
100102
Value firstBuffer = *maybeFirstBuffer;
101103

102104
FailureOr<Value> maybeSecondBuffer =
103-
getBuffer(rewriter, almostEqOp.getRhs(), options);
105+
getBuffer(rewriter, almostEqOp.getRhs(), options, state);
104106
if (failed(maybeSecondBuffer))
105107
return failure();
106108
Value secondBuffer = *maybeSecondBuffer;
@@ -142,10 +144,11 @@ struct ExpectSaneLayoutInterface
142144
}
143145

144146
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
145-
const BufferizationOptions &options) const {
147+
const BufferizationOptions &options,
148+
BufferizationState &state) const {
146149
check::ExpectSaneOp saneOp = cast<check::ExpectSaneOp>(op);
147150
FailureOr<Value> maybeBuffer =
148-
getBuffer(rewriter, saneOp.getOperand(), options);
151+
getBuffer(rewriter, saneOp.getOperand(), options, state);
149152
if (failed(maybeBuffer)) {
150153
return failure();
151154
}

lib/TPP/Dialect/Perf/BufferizableOpInterfaceImpl.cpp

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,12 @@ struct SinkLayoutInterface
5555
}
5656

5757
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
58-
const BufferizationOptions &options) const {
58+
const BufferizationOptions &options,
59+
BufferizationState &state) const {
5960
auto sink = cast<perf::SinkOp>(op);
6061

61-
FailureOr<Value> srcBuffer = getBuffer(rewriter, sink.getInput(), options);
62+
FailureOr<Value> srcBuffer =
63+
getBuffer(rewriter, sink.getInput(), options, state);
6264
if (failed(srcBuffer))
6365
return failure();
6466

lib/TPP/Transforms/LowerPacksAndUnpacks.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
112112
forLoops);
113113
if (!fusedProducer)
114114
continue;
115-
rewriter.replaceOp(consumerPackOp, tilingResult->mergeResult.replacements);
115+
rewriter.replaceOp(consumerPackOp, tilingResult->replacements);
116116
}
117117

118118
// Tile packs.
@@ -124,7 +124,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
124124
rewriter, cast<TilingInterface>(packOp.getOperation()), tileSizes);
125125
if (failed(tilingResult))
126126
continue;
127-
rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements);
127+
rewriter.replaceOp(packOp, tilingResult->replacements);
128128
}
129129

130130
// Tile unpacks.
@@ -136,7 +136,7 @@ static void fuseOrTilePacks(RewriterBase &rewriter, FunctionOpInterface func) {
136136
rewriter, cast<TilingInterface>(unPackOp.getOperation()), tileSizes);
137137
if (failed(tilingResult))
138138
continue;
139-
rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements);
139+
rewriter.replaceOp(unPackOp, tilingResult->replacements);
140140
}
141141
}
142142

@@ -215,7 +215,7 @@ class LowerPacksAndUnPacks
215215
unpackTilingOptions);
216216
if (failed(tilingResult))
217217
return signalPassFailure();
218-
rewriter.replaceOp(unPackOp, tilingResult->mergeResult.replacements);
218+
rewriter.replaceOp(unPackOp, tilingResult->replacements);
219219
});
220220
getOperation()->walk([&](linalg::PackOp packOp) {
221221
SmallVector<int64_t> tiles(packOp.getSourceType().getRank(), 1);
@@ -226,7 +226,7 @@ class LowerPacksAndUnPacks
226226
packTilingOptions);
227227
if (failed(tilingResult))
228228
return signalPassFailure();
229-
rewriter.replaceOp(packOp, tilingResult->mergeResult.replacements);
229+
rewriter.replaceOp(packOp, tilingResult->replacements);
230230
});
231231
RewritePatternSet patterns(&getContext());
232232
patterns.add<linalg::DecomposeOuterUnitDimsUnPackOpPattern,

lib/TPP/Transforms/RewriteBatchMatmulToMatmul.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ struct RewriteBatchMatmulToMatmul
111111
tilingOpts);
112112
if (failed(tilingResult))
113113
return signalPassFailure();
114-
rewriter.replaceOp(batchMatmulOp, tilingResult->mergeResult.replacements);
114+
rewriter.replaceOp(batchMatmulOp, tilingResult->replacements);
115115
});
116116

117117
// Step2:

lib/TPP/Transforms/SplitReductionDim.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ struct SplitContractionReduction
8181
return rewriter.notifyMatchFailure(linalgOp,
8282
"failed to tile contraction");
8383

84-
rewriter.replaceOp(linalgOp, tilingResult->mergeResult.replacements);
84+
rewriter.replaceOp(linalgOp, tilingResult->replacements);
8585

8686
return success();
8787
}

lib/TPP/Transforms/VectorContractToAMX.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,8 +344,8 @@ struct VectorContractToAMXPattern
344344
return rewriter.notifyMatchFailure(
345345
op, "Accumulator defined by TransferReadOp");
346346

347-
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
348-
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex))
347+
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) ||
348+
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger))
349349
return rewriter.notifyMatchFailure(
350350
op, "Inputs are not whole tensor or subview");
351351

lib/TPP/Transforms/VectorContractToFMA.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,8 +174,8 @@ struct VectorContractToFMAPattern
174174
return failure();
175175

176176
// Make sure the inputs being read are whole tensor or subview.
177-
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
178-
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
177+
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) ||
178+
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger)) {
179179
return failure();
180180
}
181181

lib/TPP/Transforms/VectorContractToOuterproduct.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,8 +133,8 @@ struct VectorContractToOuterproductPattern
133133
return failure();
134134

135135
// Make sure the inputs being read are whole tensor or subview.
136-
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroIndex) ||
137-
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroIndex)) {
136+
if (!llvm::all_of(lhsDefiningOp.getIndices(), isZeroInteger) ||
137+
!llvm::all_of(rhsDefiningOp.getIndices(), isZeroInteger)) {
138138
return failure();
139139
}
140140

python/mlir/tpp/sched/bundles.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from typing import Optional, Sequence
22

3+
from mlir import ir
34
from mlir.dialects import transform
45
from .common import apply_registered_pass, match
56
from .utils import GpuBackend, PipelineInterrupt
@@ -67,7 +68,7 @@ def linalg_lowering(mod, /, *, skip_operations: Sequence[str] = (), **_config):
6768
func = apply_registered_pass(
6869
func,
6970
"convert-linalg-to-xsmm",
70-
options="skip-operations=" + ",".join(skip_operations),
71+
options={"skip-operations": ",".join(skip_operations)},
7172
)
7273
func = apply_registered_pass(func, "combine-xsmm-op-optimization")
7374
func = apply_registered_pass(func, "fold-xsmm-flags")
@@ -130,7 +131,7 @@ def low_level_parallel(
130131
# Run cleanup after LICM to allow CSE to eliminate common operations now
131132
# that they are hoisted out of loops.
132133
mod = cleanup(mod)
133-
options = "parallel-loop-tile-sizes=" + ",".join(map(str, parallel_task_grid))
134+
options = {"parallel-loop-tile-sizes": ",".join(map(str, parallel_task_grid))}
134135
mod = apply_registered_pass(mod, "scf-parallel-loop-tiling", options=options)
135136
return mod
136137

@@ -228,7 +229,7 @@ def default_tpp_passes(
228229
mod = linalg_lowering(mod, skip_operations=skip_ops, **config)
229230
if linalg_to_vector or force_linalg_to_vector:
230231
func = match(mod, ops={"func.func"})
231-
options = "registerTileShape=" + ",".join(map(str, register_blocking))
232+
options = {"registerTileShape": ",".join(map(str, register_blocking))}
232233
func = apply_registered_pass(func, "brgemm-linalg-tiling", options=options)
233234
func = apply_registered_pass(func, "loop-invariant-code-motion")
234235
apply_registered_pass(func, "vectorization-pass")
@@ -315,7 +316,7 @@ def default_pipeline(
315316
# #if defined(__x86_64__)
316317
# options.x86Vector = true;
317318
# #endif
318-
options = f"enable-amx={int(xsmm_utils.has_amx())}"
319+
options = {"enable-amx": int(xsmm_utils.has_amx())}
319320
mod = apply_registered_pass(mod, "convert-vector-to-llvm", options=options)
320321
mod = apply_registered_pass(mod, "finalize-memref-to-llvm")
321322
mod = apply_registered_pass(mod, "convert-scf-to-cf")
@@ -327,9 +328,8 @@ def default_pipeline(
327328
# gpu-to-llvm cannot be invoked from transform-interpreter as it
328329
# tries to load ... something while multi-threaded PassManager is running.
329330
mod = apply_registered_pass(mod, "gpu-to-llvm")
330-
mod = apply_registered_pass(
331-
mod, "gpu-module-to-binary", options="compilation-target=fatbin"
332-
)
331+
options = {"compilation-target": "fatbin"}
332+
mod = apply_registered_pass(mod, "gpu-module-to-binary", options=options)
333333
mod = apply_registered_pass(mod, "convert-math-to-llvm")
334334
if gpu_backend:
335335
mod = apply_registered_pass(mod, "async-to-async-runtime")

0 commit comments

Comments
 (0)