Commit 120e5c4

Migrate the entire bufferization pipeline to one-shot bufferization (#1751)
**Context:** This work is based on #1027. Now that we have migrated all the individual dialects, we should migrate the entire bufferization pipeline.

The `Quantum` dialect was migrated in #1686. The `Catalyst` dialect was migrated in #1708. The `Gradient` dialect was migrated in #1740. See more context in #1027.

Upstream changes in llvm were required for this bufferization update. As a result, the llvm version and mlir-hlo version were updated to

```
mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
```

These are the versions tracked by jax 0.4.32. These are the earliest jax-tagged versions with complete upstream bufferization changes.

**Related GitHub Issues:** [sc-71487]

---------

Co-authored-by: Tzung-Han Juang <[email protected]>
1 parent 4b490db commit 120e5c4

26 files changed, +109 -113 lines changed

.dep-versions

Lines changed: 10 additions & 2 deletions
@@ -1,7 +1,15 @@
 # Always update the version check in catalyst.__init__ when changing the JAX version.
+
+#############
+# We track mlir submodule versions from jax 0.4.32 for now
+# These are the earliest versions with complete upstream bufferization changes
+# Versions are retrieved from
+# python3 .github/workflows/set_dep_versions.py 0.4.32
+#############
+
 jax=0.6.0
-mhlo=89a891c986650c33df76885f5620e0a92150d90f
-llvm=3a8316216807d64a586b971f51695e23883331f7
+mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
+llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
 enzyme=v0.0.149
 
 # Always remove custom PL/LQ versions before release.
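
For readers who want to check the pinned hashes programmatically, here is a minimal sketch of reading the `key=value` layout shown in `.dep-versions`. It assumes only what is visible in the hunk: `#` starts a comment line and every other non-empty line is `name=version`. The helper name `parse_dep_versions` is hypothetical; the repository's own `set_dep_versions.py` may work differently.

```python
from typing import Dict


def parse_dep_versions(text: str) -> Dict[str, str]:
    """Parse key=value lines, skipping blank lines and '#' comments (hypothetical helper)."""
    versions: Dict[str, str] = {}
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line or line.startswith("#"):
            continue
        key, sep, value = line.partition("=")
        if sep:  # ignore lines that are not key=value pairs
            versions[key.strip()] = value.strip()
    return versions


# Sample content copied from the new side of the hunk above.
SAMPLE = """\
# Always update the version check in catalyst.__init__ when changing the JAX version.
jax=0.6.0
mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
enzyme=v0.0.149
"""

versions = parse_dep_versions(SAMPLE)
```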

.github/workflows/build-wheel-linux-arm64.yaml

Lines changed: 1 addition & 1 deletion
@@ -222,7 +222,7 @@ jobs:
               -DCMAKE_CXX_VISIBILITY_PRESET=default \
               -DCMAKE_CXX_FLAGS="-fuse-ld=lld"
 
-          cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-19
+          cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
 
       - name: Save Enzyme Build
         id: save-enzyme-build

.github/workflows/build-wheel-linux-x86_64.yaml

Lines changed: 1 addition & 1 deletion
@@ -245,7 +245,7 @@ jobs:
               -DCMAKE_CXX_VISIBILITY_PRESET=default \
               -DCMAKE_CXX_FLAGS="-fuse-ld=lld"
 
-          cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-19
+          cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
 
       - name: Save Enzyme Build
         id: save-enzyme-build

.github/workflows/build-wheel-macos-arm64.yaml

Lines changed: 1 addition & 1 deletion
@@ -218,7 +218,7 @@ jobs:
               -DENZYME_STATIC_LIB=ON \
               -DCMAKE_CXX_VISIBILITY_PRESET=default
 
-          cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-19
+          cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
 
       - name: Save Enzyme Build
         id: save-enzyme-build

doc/releases/changelog-dev.md

Lines changed: 1 addition & 0 deletions
@@ -213,6 +213,7 @@
   [(#1686)](https://github.com/PennyLaneAI/catalyst/pull/1686)
   [(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)
   [(#1740)](https://github.com/PennyLaneAI/catalyst/pull/1740)
+  [(#1751)](https://github.com/PennyLaneAI/catalyst/pull/1751)
 
 * Redundant `OptionalAttr` is removed from `adjoint` argument in `QuantumOps.td` TableGen file
   [(#1746)](https://github.com/PennyLaneAI/catalyst/pull/1746)

frontend/catalyst/pipelines.py

Lines changed: 19 additions & 17 deletions
@@ -213,32 +213,33 @@ def get_quantum_compilation_stage(options: CompileOptions) -> List[str]:
     return list(filter(partial(is_not, None), quantum_compilation))
 
 
-def get_bufferization_stage(_options: CompileOptions) -> List[str]:
+def get_bufferization_stage(options: CompileOptions) -> List[str]:
     """Returns the list of passes that performs bufferization"""
+
+    bufferization_options = """bufferize-function-boundaries
+        allow-return-allocs-from-loops
+        function-boundary-type-conversion=identity-layout-map
+        unknown-type-conversion=identity-layout-map""".replace(
+        "\n", " "
+    )
+    if options.async_qnodes:
+        bufferization_options += " copy-before-write"
+
     bufferization = [
-        "one-shot-bufferize{dialect-filter=memref}",
         "inline",
-        "gradient-preprocess",
-        "one-shot-bufferize{dialect-filter=gradient unknown-type-conversion=identity-layout-map}",
-        "scf-bufferize",
         "convert-tensor-to-linalg",  # tensor.pad
-        "convert-elementwise-to-linalg",  # Must be run before --arith-bufferize
-        "arith-bufferize",
-        "empty-tensor-to-alloc-tensor",
-        "func.func(bufferization-bufferize)",
-        "func.func(tensor-bufferize)",
-        # Catalyst dialect's bufferization must be run before --func.func(linalg-bufferize)
-        "one-shot-bufferize{dialect-filter=catalyst unknown-type-conversion=identity-layout-map}",
-        "func.func(linalg-bufferize)",
-        "func.func(tensor-bufferize)",
-        "one-shot-bufferize{dialect-filter=quantum}",
-        "func-bufferize",
-        "func.func(finalizing-bufferize)",
+        "convert-elementwise-to-linalg",  # Must be run before --one-shot-bufferize
+        "gradient-preprocess",
+        "eliminate-empty-tensors",
+        ####################
+        "one-shot-bufferize{" + bufferization_options + "}",
+        ####################
         "canonicalize",  # Remove dead memrefToTensorOp's
         "gradient-postprocess",
         # introduced during gradient-bufferize of callbacks
         "func.func(buffer-hoisting)",
         "func.func(buffer-loop-hoisting)",
+        "func.func(promote-buffers-to-stack)",
         "func.func(buffer-deallocation)",
         "convert-arraylist-to-memref",
         "convert-bufferization-to-memref",
@@ -247,6 +248,7 @@ def get_bufferization_stage(_options: CompileOptions) -> List[str]:
         # "cse",
         "cp-global-memref",
     ]
+
     return bufferization
 
 
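
To make the new pass string concrete, the following standalone sketch rebuilds the `one-shot-bufferize{...}` entry from the option names in the hunk above. The helper names `build_bufferization_options` and `one_shot_bufferize_pass` are hypothetical and not part of Catalyst, and the sketch joins the options with single spaces, whereas the committed code replaces newlines in a triple-quoted string and so keeps whatever indentation that string contains.

```python
def build_bufferization_options(async_qnodes: bool) -> str:
    """Join the one-shot bufferization options from the diff into one string (hypothetical helper)."""
    options = [
        "bufferize-function-boundaries",
        "allow-return-allocs-from-loops",
        "function-boundary-type-conversion=identity-layout-map",
        "unknown-type-conversion=identity-layout-map",
    ]
    # The diff appends this option only when async QNodes are enabled.
    if async_qnodes:
        options.append("copy-before-write")
    return " ".join(options)


def one_shot_bufferize_pass(async_qnodes: bool = False) -> str:
    """Return the single pipeline entry that replaces the per-dialect bufferization passes."""
    return "one-shot-bufferize{" + build_bufferization_options(async_qnodes) + "}"


if __name__ == "__main__":
    print(one_shot_bufferize_pass(async_qnodes=False))
    print(one_shot_bufferize_pass(async_qnodes=True))
```

Building the string from a list keeps each option on its own line for review while avoiding stray whitespace in the final pass argument; whether the pass-option parser is sensitive to that whitespace is not stated in the commit.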

mlir/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ set(ALL_MHLO_PASSES
     HloToLinalgUtils
     MhloToLinalg
     MhloToStablehlo
-    MhloQuantToIntConversion
     StablehloToMhlo
 )
 

mlir/Makefile

Lines changed: 5 additions & 2 deletions
@@ -85,7 +85,10 @@ llvm:
 
 	# TODO: when updating LLVM, test to see if mlir/unittests/Bytecode/BytecodeTest.cpp:55 is passing
 	# and remove filter. This tests fails on CI/CD not locally.
-	LIT_FILTER_OUT="Bytecode|tosa-to-tensor" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS)
+	# Note: the upstream lit test llvm-project/mlir/test/python/execution_engine.py requries
+	# the python package `ml_dtypes`. We don't actually use the execution engine, so we skip the
+	# test to reduce unnecessary dependencies.
+	LIT_FILTER_OUT="Bytecode|tosa-to-tensor|execution_engine" cmake --build $(LLVM_BUILD_DIR) --target $(LLVM_TARGETS)
 
 .PHONY: mhlo
 mhlo: TARGET_FILE := $(MK_DIR)/mlir-hlo/mhlo/transforms/CMakeLists.txt
@@ -130,7 +133,7 @@ enzyme:
 		-DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \
 		-DCMAKE_POLICY_DEFAULT_CMP0116=NEW
 
-	cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-19
+	cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-20
 
 .PHONY: plugin
 plugin:
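
The effect of the extended `LIT_FILTER_OUT` value can be previewed with a small Python sketch. It assumes lit treats the value as a regular expression searched against each test's name; the test names below are illustrative (only the `execution_engine.py` path comes from the Makefile comment), and lit's exact matching rules, such as case sensitivity, should be checked against its documentation.

```python
import re

# Value taken from the new Makefile line in the hunk above.
LIT_FILTER_OUT = "Bytecode|tosa-to-tensor|execution_engine"


def is_filtered_out(test_name: str, pattern: str = LIT_FILTER_OUT) -> bool:
    """Return True if the pattern matches anywhere in the test name (assumed lit behaviour)."""
    return re.search(pattern, test_name) is not None


# Hypothetical test names, used only to illustrate the alternation.
candidates = [
    "MLIR :: python/execution_engine.py",
    "MLIR-Unit :: Bytecode/BytecodeTest",
    "MLIR :: Conversion/TosaToTensor/tosa-to-tensor.mlir",
    "MLIR :: Dialect/Bufferization/one-shot-bufferize.mlir",
]

skipped = [name for name in candidates if is_filtered_out(name)]
kept = [name for name in candidates if not is_filtered_out(name)]
```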

mlir/lib/Catalyst/Transforms/DetectQNodes.cpp

Lines changed: 1 addition & 1 deletion
@@ -925,7 +925,7 @@ struct AddExceptionHandlingPass : impl::AddExceptionHandlingPassBase<AddExceptio
 
         GreedyRewriteConfig config;
         config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        config.enableRegionSimplification = false;
+        config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
 
         if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns1), config))) {
             signalPassFailure();

mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp

Lines changed: 2 additions & 2 deletions
@@ -383,7 +383,7 @@ struct AnnotateWithFullyQualifiedNamePass
         // Do not fold to save in compile time.
         GreedyRewriteConfig config;
         config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        config.enableRegionSimplification = false;
+        config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
 
         RewritePatternSet annotate(context);
         auto root = getOperation();
@@ -409,7 +409,7 @@ struct InlineNestedSymbolTablePass : PassWrapper<InlineNestedSymbolTablePass, Op
 
         GreedyRewriteConfig config;
         config.strictMode = GreedyRewriteStrictness::ExistingOps;
-        config.enableRegionSimplification = false;
+        config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled;
 
         RewritePatternSet renameFunctions(context);
 
415415
