Skip to content

Commit 38469e4

Browse files
rniczhdime10
andauthored
Update LLVM version, 2025 Q4 (#2122)
**Context:** Update llvm, stablehlo and enzyme, 2025 Q4 The latest pair of good versions, indicated by stablehlo, is openxla/stablehlo@0a4440a ``` stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d llvm=113f01aa82d055410f22a9d03b3468fa68600589 ``` For Enzyme, we go to the latest release https://github.com/EnzymeAD/Enzyme/releases/tag/v0.0.203 ``` enzyme=v0.0.203 ``` with commit `476c8e3193a8577ba24ff845ae2294109225f83a` **Description of the Change:** - `llvm-project/mlir/include/mlir/InitAllPasses.h` header no longer includes individual pass headers. To incorporate specific passes, you now need to include either `mlir/Conversion/Passes.h` or the header for each dialect pass. This change stems from the upstream cleanup of `mlir/InitAllDialects.h`, which previously handled these includes. Following the decision from the owner (as seen in this pull request comment: https://github.com/ftynse/water/pull/16/files#r2259526343), we will specifically include each required header, faster compilation (no unnecessary headers). For further context, refer to this PR: llvm/llvm-project#151150 - To bufferize custom tensors into custom buffers. The upstream PR llvm/llvm-project#144867 changed the the return type of `bufferization::detail::defaultGetBufferType` to be `BufferLikeType` instead of `BaseMemRefType`. We aligned the behaviour with the upstream PR, return the BufferLikeType from our `getBufferType()` implementation as well. - `elemwise_unary` and `elemwise_binary` of `linalg` have been deprecated from the upstream, replaced with `elementwise`. llvm/llvm-project#147082. Follow this PR, using `linalg.add` directly. register_all_dialects/passes/extension` should now be configured as a static library in CMake, refer to this PR: llvm/llvm-project#151150 - An opt-in feature (`{op}::create` method) be added to support the same behaviour of `rewriter.create` method. But it provides a meaningful message when using this new opt-in feature compared to `rewrite.create`. Since the new generated `create` function will call the `build` function, it requires us to define the instance of the builder (`GraidentOps.td`). Refer to the PR: llvm/llvm-project#147168. **Take `linalg.abs` as an example:** ```cpp void AbsOp::build(::mlir::OpBuilder &odsBuilder, ::mlir::OperationState &odsState, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes) { buildStructuredOp(odsBuilder, odsState, std::nullopt, inputs, outputs, attributes, AbsOp::getRegionBuilder()); } AbsOp AbsOp::create(::mlir::OpBuilder &builder, ::mlir::Location location, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes) { ::mlir::OperationState __state__(location, getOperationName()); build(builder, __state__, inputs, outputs, attributes); auto __res__ = ::llvm::dyn_cast<AbsOp>(builder.create(__state__)); assert(__res__ && "builder didn't return the right type"); return __res__; } AbsOp AbsOp::create(::mlir::ImplicitLocOpBuilder &builder, ValueRange inputs, ValueRange outputs, ArrayRef<NamedAttribute> attributes) { return create(builder, builder.getLoc(), inputs, outputs, attributes); } ``` - Changed `EnzymeStatic-21` to `EnzymeStatic-22` - Mock `_ods_cext.globals.register_traceback_file_exclusion` due to API conflicts between Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version we used has not yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed once JAX updates to a compatible MLIR version. llvm/llvm-project#151246 TODO: - [x] Update .dep_version - [ ] Change all `rewrite.create` to `{op}::create` Even it’s compatible to the current `rewrite.create` method, we don’t need to change all `rewrite.create` to `{op}::create` right away, but this new feature provides more friendly development experience. Just for an example: llvm/llvm-project#147311 **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** [sc-100911] [sc-100912] --------- Co-authored-by: David Ittah <[email protected]>
1 parent e1274ed commit 38469e4

File tree

18 files changed

+307
-108
lines changed

18 files changed

+307
-108
lines changed

.dep-versions

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
# To update JAX version alongside compatible dependency tags, run the following script:
33
# python3 .github/workflows/set_dep_versions.py {JAX_version}
44
jax=0.6.2
5-
stablehlo=69d6dae46e1c7de36e6e6973654754f05353cba5
6-
llvm=f8cb7987c64dcffb72414a40560055cb717dbf74
7-
enzyme=v0.0.186
5+
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
6+
llvm=113f01aa82d055410f22a9d03b3468fa68600589
7+
enzyme=v0.0.203
88

99
# Always remove custom PL/LQ versions before release.
1010

frontend/catalyst/jax_extras/patches.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,29 @@
4343
)
4444

4545

46+
def mock_attributes(obj, attrs: dict[str, any]):
47+
"""Mock the attribute of an object by returning a wrapper.
48+
49+
Args:
50+
obj: The object to mock the attributes of.
51+
attrs: A dictionary of attributes to mock.
52+
Example: {"attribute_name": attribute_value}
53+
"""
54+
55+
class MockAttributeWrapper:
56+
"""Wrapper to mock the attribute of an object."""
57+
58+
def __init__(self, original):
59+
self.original = original
60+
61+
def __getattr__(self, name):
62+
if name in attrs:
63+
return attrs[name]
64+
return getattr(self.original, name)
65+
66+
return MockAttributeWrapper(obj)
67+
68+
4669
def _drop_unused_vars2(jaxpr, constvals):
4770
"""
4871
A patch to not drop unused vars during classical tracing of control flow.

frontend/catalyst/jax_primitives.py

Lines changed: 80 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
from jax.interpreters import mlir
4040
from jax.tree_util import PyTreeDef, tree_unflatten
4141
from jaxlib.hlo_helpers import shape_dtype_to_ir_type
42+
from jaxlib.mlir._mlir_libs import _mlir as _ods_cext
4243
from jaxlib.mlir.dialects.arith import (
4344
AddIOp,
4445
CeilDivSIOp,
@@ -52,54 +53,85 @@
5253
from jaxlib.mlir.dialects.scf import ConditionOp, ForOp, IfOp, WhileOp, YieldOp
5354
from jaxlib.mlir.dialects.stablehlo import ConstantOp as StableHLOConstantOp
5455
from jaxlib.mlir.dialects.stablehlo import ConvertOp as StableHLOConvertOp
55-
from mlir_quantum.dialects.catalyst import (
56-
AssertionOp,
57-
CallbackCallOp,
58-
CallbackOp,
59-
PrintOp,
60-
)
61-
from mlir_quantum.dialects.gradient import (
62-
CustomGradOp,
63-
ForwardOp,
64-
GradOp,
65-
JVPOp,
66-
ReverseOp,
67-
ValueAndGradOp,
68-
VJPOp,
69-
)
70-
from mlir_quantum.dialects.mbqc import MeasureInBasisOp
71-
from mlir_quantum.dialects.mitigation import ZneOp
72-
from mlir_quantum.dialects.quantum import (
73-
AdjointOp,
74-
AllocOp,
75-
ComputationalBasisOp,
76-
CountsOp,
77-
CustomOp,
78-
DeallocOp,
79-
DeallocQubitOp,
80-
DeviceInitOp,
81-
DeviceReleaseOp,
82-
ExpvalOp,
83-
ExtractOp,
84-
GlobalPhaseOp,
85-
HamiltonianOp,
86-
HermitianOp,
87-
InsertOp,
88-
MeasureOp,
89-
MultiRZOp,
90-
NamedObsOp,
91-
NumQubitsOp,
92-
PCPhaseOp,
93-
ProbsOp,
94-
QubitUnitaryOp,
95-
SampleOp,
96-
SetBasisStateOp,
97-
SetStateOp,
98-
StateOp,
99-
TensorOp,
100-
VarianceOp,
101-
)
102-
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
56+
57+
# TODO: remove after jax v0.7.2 upgrade
58+
# Mock _ods_cext.globals.register_traceback_file_exclusion due to API conflicts between
59+
# Catalyst's MLIR version and the MLIR version used by JAX. The current JAX version has not
60+
# yet updated to the latest MLIR, causing compatibility issues. This workaround will be removed
61+
# once JAX updates to a compatible MLIR version
62+
# pylint: disable=ungrouped-imports
63+
from catalyst.jax_extras.patches import mock_attributes
64+
from catalyst.utils.patching import Patcher
65+
66+
with Patcher(
67+
(
68+
_ods_cext,
69+
"globals",
70+
mock_attributes(
71+
# pylint: disable=c-extension-no-member
72+
_ods_cext.globals,
73+
{"register_traceback_file_exclusion": lambda x: None},
74+
),
75+
),
76+
):
77+
from mlir_quantum.dialects.catalyst import (
78+
AssertionOp,
79+
CallbackCallOp,
80+
CallbackOp,
81+
PrintOp,
82+
)
83+
from mlir_quantum.dialects.gradient import (
84+
CustomGradOp,
85+
ForwardOp,
86+
GradOp,
87+
JVPOp,
88+
ReverseOp,
89+
ValueAndGradOp,
90+
VJPOp,
91+
)
92+
from mlir_quantum.dialects.mbqc import MeasureInBasisOp
93+
from mlir_quantum.dialects.mitigation import ZneOp
94+
from mlir_quantum.dialects.quantum import (
95+
AdjointOp,
96+
AllocOp,
97+
ComputationalBasisOp,
98+
CountsOp,
99+
CustomOp,
100+
DeallocOp,
101+
DeallocQubitOp,
102+
DeviceInitOp,
103+
DeviceReleaseOp,
104+
ExpvalOp,
105+
ExtractOp,
106+
GlobalPhaseOp,
107+
HamiltonianOp,
108+
HermitianOp,
109+
InsertOp,
110+
MeasureOp,
111+
MultiRZOp,
112+
NamedObsOp,
113+
NumQubitsOp,
114+
PCPhaseOp,
115+
ProbsOp,
116+
QubitUnitaryOp,
117+
SampleOp,
118+
SetBasisStateOp,
119+
SetStateOp,
120+
StateOp,
121+
TensorOp,
122+
VarianceOp,
123+
)
124+
from mlir_quantum.dialects.quantum import YieldOp as QYieldOp
125+
from catalyst.jax_primitives_utils import (
126+
cache,
127+
create_call_op,
128+
get_cached,
129+
get_call_jaxpr,
130+
get_symbolref,
131+
lower_callable,
132+
lower_jaxpr,
133+
)
134+
103135
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
104136

105137
from catalyst.compiler import get_lib_path
@@ -111,19 +143,9 @@
111143
infer_output_type_jaxpr,
112144
while_loop_expansion_strategy,
113145
)
114-
from catalyst.jax_primitives_utils import (
115-
cache,
116-
create_call_op,
117-
get_cached,
118-
get_call_jaxpr,
119-
get_symbolref,
120-
lower_callable,
121-
lower_jaxpr,
122-
)
123146
from catalyst.utils.calculate_grad_shape import Signature, calculate_grad_shape
124147
from catalyst.utils.exceptions import CompileError
125148
from catalyst.utils.extra_bindings import FromElementsOp, TensorExtractOp
126-
from catalyst.utils.patching import Patcher
127149
from catalyst.utils.types import convert_shaped_arrays_to_tensors
128150

129151
# pylint: disable=unused-argument,too-many-lines,too-many-statements,protected-access
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
# Copyright 2025 Xanadu Quantum Technologies Inc.
2+
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""Tests for the jax_extras.patches module"""
15+
16+
from catalyst.jax_extras.patches import mock_attributes
17+
18+
19+
# pylint: disable=missing-class-docstring,missing-function-docstring
20+
class TestMockAttributes:
21+
"""Test the mock_attributes function and MockAttributeWrapper class."""
22+
23+
def test_mock_attributes_returns_mocked_value(self):
24+
"""Test that accessing a mocked attribute returns the mocked value."""
25+
26+
class DummyClass:
27+
def __init__(self):
28+
self.original_attr = "original"
29+
30+
obj = DummyClass()
31+
mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"})
32+
33+
# Access the mocked attribute - this should come from the attrs dict
34+
assert mocked.mocked_attr == "mocked_value"
35+
36+
def test_mock_attributes_returns_original_value(self):
37+
"""Test that accessing an unmocked attribute returns the original value."""
38+
39+
class DummyClass:
40+
def __init__(self):
41+
self.original_attr = "original"
42+
43+
obj = DummyClass()
44+
mocked = mock_attributes(obj, {"mocked_attr": "mocked_value"})
45+
46+
# Access the original attribute - this should come from the original object
47+
# This tests the else branch in __getattr__
48+
assert mocked.original_attr == "original"
49+
50+
def test_mock_attributes_with_methods(self):
51+
"""Test that calling original methods works through the wrapper."""
52+
53+
class DummyClass:
54+
def __init__(self):
55+
self.value = 42
56+
57+
def get_value(self):
58+
return self.value
59+
60+
obj = DummyClass()
61+
62+
def mocked_method():
63+
return "mocked"
64+
65+
mocked = mock_attributes(obj, {"mocked_method": mocked_method})
66+
67+
# Access the mocked method
68+
assert mocked.mocked_method() == "mocked"
69+
70+
# Access the original method - tests the getattr fallback
71+
assert mocked.get_value() == 42
72+
73+
def test_mock_attributes_with_callable(self):
74+
"""Test mocking with callable attributes like lambda functions."""
75+
76+
class DummyClass:
77+
def __init__(self):
78+
self.original_func = lambda x: x * 2
79+
80+
obj = DummyClass()
81+
mocked = mock_attributes(obj, {"new_func": lambda x: x * 3})
82+
83+
# Access the mocked callable
84+
assert mocked.new_func(5) == 15
85+
86+
# Access the original callable - tests the getattr fallback
87+
assert mocked.original_func(5) == 10
88+
89+
def test_mock_attributes_override_existing(self):
90+
"""Test that mocking can override existing attributes."""
91+
92+
class DummyClass:
93+
def __init__(self):
94+
self.attr = "original"
95+
96+
obj = DummyClass()
97+
mocked = mock_attributes(obj, {"attr": "overridden"})
98+
99+
# The mocked value should take precedence
100+
assert mocked.attr == "overridden"
101+
102+
def test_mock_attributes_stores_original(self):
103+
"""Test that the original object is accessible through the wrapper."""
104+
105+
class DummyClass:
106+
def __init__(self):
107+
self.value = 100
108+
109+
obj = DummyClass()
110+
mocked = mock_attributes(obj, {})
111+
112+
# The wrapper should store the original object
113+
assert mocked.original is obj
114+
assert mocked.original.value == 100

mlir/Enzyme

Submodule Enzyme updated 153 files

mlir/Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ enzyme:
138138
-DCMAKE_CXX_VISIBILITY_PRESET=$(SYMBOL_VISIBILITY) \
139139
-DCMAKE_POLICY_DEFAULT_CMP0116=NEW
140140

141-
cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-21
141+
cmake --build $(ENZYME_BUILD_DIR) --target EnzymeStatic-22
142142

143143
.PHONY: plugin
144144
plugin:

mlir/include/Catalyst/IR/CatalystOps.td

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -138,8 +138,17 @@ def CallbackOp : Catalyst_Op<"callback",
138138

139139
let builders = [OpBuilder<(ins
140140
"mlir::StringRef":$name, "mlir::FunctionType":$type,
141-
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs)
142-
>];
141+
CArg<"mlir::ArrayRef<mlir::NamedAttribute>", "{}">:$attrs), [{
142+
$_state.addAttribute(mlir::SymbolTable::getSymbolAttrName(),
143+
$_builder.getStringAttr(name));
144+
$_state.addAttribute("function_type", mlir::TypeAttr::get(type));
145+
$_state.addAttribute("id", $_builder.getI64IntegerAttr(0));
146+
$_state.addAttribute("argc", $_builder.getI64IntegerAttr(type.getNumInputs()));
147+
$_state.addAttribute("resc", $_builder.getI64IntegerAttr(type.getNumResults()));
148+
$_state.attributes.append(attrs.begin(), attrs.end());
149+
$_state.addRegion();
150+
}]>
151+
];
143152

144153
let extraClassDeclaration = [{
145154
//===------------------------------------------------------------------===//

0 commit comments

Comments
 (0)