Skip to content

Commit b720507

Browse files
committed
As part of the llvm update, the old --buffer-deallocation pass is removed.
Intended replacement is --buffer-deallocation-pipeline. [mlir][bufferization] Remove buffer-deallocation pass llvm/llvm-project#126366, https://discourse.llvm.org/t/psa-bufferization-new-buffer-deallocation-pipeline/73375. However, we encountered a few issues. See detailed discord discussion at https://discord.com/channels/636084430946959380/642426447167881246/1376919538301403276 The heart of the new deallocation pass, --ownership-based-buffer-deallocation, blanket-ly fails for ops with unknown memory effects: https://github.com/llvm/llvm-project/blob/da4958ae2b384c2a027cf20c67b7e211d39fcbfe/mlir/lib/Dialect/Bufferization/Transforms/OwnershipBasedBufferDeallocation.cpp#L523 To solve this issue, the suggestion was to add memory effects to custom operations. The migration guide also suggests that bufferizable ops no longer implement the bufferizesToAllocation method, so we remove them. This was supposed to be done alongside the llvm update in #1752; However, soon it became clear that this migration to the new buffer deallocation is very complicated, and should be its own story. The llvm update in #1752 thus did not finish this migration. This PR records the work that was already done on this back in #1752.
1 parent 1728ec1 commit b720507

File tree

83 files changed

+2029
-404
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

83 files changed

+2029
-404
lines changed

.dep-versions

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,10 @@
11
# Always update the version check in catalyst.__init__ when changing the JAX version.
2-
3-
#############
4-
# We track mlir submodule versions from jax 0.4.32 for now
5-
# These are the earliest versions with complete upstream bufferization changes
6-
# Versions are retrieved from
7-
# python3 .github/workflows/set_dep_versions.py 0.4.32
8-
#############
9-
2+
# To update JAX version alongside compatible dependency tags, run the following script:
3+
# python3 .github/workflows/set_dep_versions.py {JAX_version}
104
jax=0.6.0
11-
mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
12-
llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
13-
enzyme=v0.0.149
5+
mhlo=617a9361d186199480c080c9e8c474a5e30c22d1
6+
llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158
7+
enzyme=v0.0.180
148

159
# Always remove custom PL/LQ versions before release.
1610

.github/workflows/build-wheel-linux-arm64.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ jobs:
222222
-DCMAKE_CXX_VISIBILITY_PRESET=default \
223223
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"
224224
225-
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
225+
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21
226226
227227
- name: Save Enzyme Build
228228
id: save-enzyme-build

.github/workflows/build-wheel-linux-x86_64.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,7 @@ jobs:
245245
-DCMAKE_CXX_VISIBILITY_PRESET=default \
246246
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"
247247
248-
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
248+
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21
249249
250250
- name: Save Enzyme Build
251251
id: save-enzyme-build

.github/workflows/build-wheel-macos-arm64.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,7 @@ jobs:
218218
-DENZYME_STATIC_LIB=ON \
219219
-DCMAKE_CXX_VISIBILITY_PRESET=default
220220
221-
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
221+
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21
222222
223223
- name: Save Enzyme Build
224224
id: save-enzyme-build

.github/workflows/check-catalyst.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ jobs:
146146
sudo apt-get update
147147
sudo apt-get install -y python3 python3-pip cmake ninja-build clang lld
148148
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
149-
python3 -m pip install numpy pybind11
149+
python3 -m pip install numpy pybind11 nanobind==2.4
150150
151151
- name: Build LLVM
152152
if: steps.cache-llvm-build.outputs.cache-hit != 'true'

.github/workflows/set_dep_versions.py

Lines changed: 14 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,11 +32,11 @@
3232
assert os.path.isfile(dep_versions_path)
3333
assert os.path.isfile(catalyst_init_path)
3434

35-
url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/WORKSPACE"
35+
url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/WORKSPACE"
3636
response = requests.get(url)
3737
match = re.search(r'strip_prefix = "xla-([a-zA-Z0-9]*)"', response.text)
3838
if not match:
39-
url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/third_party/xla/workspace.bzl"
39+
url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/third_party/xla/workspace.bzl"
4040
response = requests.get(url)
4141
match = re.search(r'XLA_COMMIT = "([a-zA-Z0-9]*)"', response.text)
4242
xla_commit = match.group(1)
@@ -67,21 +67,16 @@
6767
response = requests.get(url).json()
6868
hlo_commit = response["items"][0]["sha"]
6969

70-
existing_text = open(dep_versions_path, "r", encoding="UTF-8").read()
71-
match = re.search(r"enzyme=([a-zA-Z0-9]*)", existing_text)
72-
enzyme_commit = match.group(1)
73-
74-
with open(dep_versions_path, "w", encoding="UTF-8") as f:
75-
f.write(
76-
f"""\
77-
jax={jax_version}
78-
mhlo={hlo_commit}
79-
llvm={llvm_commit}
80-
enzyme={enzyme_commit}
81-
"""
82-
)
83-
8470
quote = '"'
85-
cmd = f"sed -i 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}"
86-
res = os.system(cmd)
87-
assert res == 0
71+
# Update each version using sed
72+
cmds = [
73+
f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}",
74+
f"sed -i '' 's/^mhlo=.*/mhlo={hlo_commit}/' {dep_versions_path}",
75+
f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}",
76+
# Update jaxlib version in __init__.py
77+
rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}",
78+
]
79+
80+
for cmd in cmds:
81+
res = os.system(cmd)
82+
assert res == 0

doc/dev/transforms.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -512,15 +512,15 @@ and other function operations, which themselves can contain other operations, an
512512
quantumPatterns.add<QubitUnitaryFusion>(ctx);
513513
514514
// Apply patterns in an iterative and greedy manner.
515-
if (failed(applyPatternsAndFoldGreedily(op, std::move(quantumPatterns)))) {
515+
if (failed(applyPatternsGreedily(op, std::move(quantumPatterns)))) {
516516
return signalPassFailure();
517517
}
518518
}
519519
};
520520
521521
To apply patterns we need a `pattern applicator <https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers>`_.
522522
There a few in MLIR but typically you can just use the greedy pattern rewrite driver
523-
(``applyPatternsAndFoldGreedily``), which will iterative over the IR and apply patterns until a
523+
(``applyPatternsGreedily``), which will iterative over the IR and apply patterns until a
524524
fixed point is reached.
525525

526526
.. note::

doc/releases/changelog-dev.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@
175175
[(#1671)](https://github.com/PennyLaneAI/catalyst/pull/1671)
176176
[(#1681)](https://github.com/PennyLaneAI/catalyst/pull/1681)
177177

178+
* (Compiler developers only) The version of LLVM, MHLO and Enzyme used by Catalyst is
179+
updated to track those in jax 0.6.0.
180+
[(#1752)](https://github.com/PennyLaneAI/catalyst/pull/1752)
181+
182+
The LLVM version is updated to commit 179d30f8c3fddd3c85056fd2b8e877a4a8513158.
183+
The MHLO version is updated to commit 617a9361d186199480c080c9e8c474a5e30c22d1.
184+
The Enzyme version is updated to v0.0.180.
185+
178186
* The clang-format and clang-tidy versions used by Catalyst have been updated to v20.
179187
[(#1721)](https://github.com/PennyLaneAI/catalyst/pull/1721)
180188

@@ -257,7 +265,7 @@
257265
* Improved the definition of `YieldOp` in the quantum dialect by removing `AnyTypeOf`
258266
[(#1696)](https://github.com/PennyLaneAI/catalyst/pull/1696)
259267

260-
* The assembly format of `MeasureOp` in the `Quantum` dialect and `MeasureInBasisOp` in the `MBQC` dialect now contains the `postselect` attribute.
268+
* The assembly format of `MeasureOp` in the `Quantum` dialect and `MeasureInBasisOp` in the `MBQC` dialect now contains the `postselect` attribute.
261269
[(#1732)](https://github.com/PennyLaneAI/catalyst/pull/1732)
262270

263271
* The bufferization of custom catalyst dialects has been migrated to the new one-shot

frontend/catalyst/pipelines.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -240,7 +240,7 @@ def get_bufferization_stage(options: CompileOptions) -> List[str]:
240240
"func.func(buffer-hoisting)",
241241
"func.func(buffer-loop-hoisting)",
242242
"func.func(promote-buffers-to-stack)",
243-
"func.func(buffer-deallocation)",
243+
"buffer-deallocation-pipeline",
244244
"convert-arraylist-to-memref",
245245
"convert-bufferization-to-memref",
246246
"canonicalize", # Must be after convert-bufferization-to-memref

mlir/CMakeLists.txt

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,6 @@ set(ALL_MHLO_PASSES
4747
StablehloPasses
4848
MhloToArithmeticConversion
4949
MhloToMemrefConversion
50-
MhloToStandard
5150
HloToLinalgUtils
5251
MhloToLinalg
5352
MhloToStablehlo

0 commit comments

Comments
 (0)