diff --git a/.dep-versions b/.dep-versions index 33496f09da..fd471b166c 100644 --- a/.dep-versions +++ b/.dep-versions @@ -1,7 +1,7 @@ # Always update the version check in catalyst.__init__ when changing the JAX version. # To update JAX version alongside compatible dependency tags, run the following script: # python3 .github/workflows/set_dep_versions.py {JAX_version} -jax=0.6.2 +jax=cd1b9520bd2e1b0dd451281f06e98f030347aa8b stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d llvm=113f01aa82d055410f22a9d03b3468fa68600589 enzyme=v0.0.203 diff --git a/frontend/catalyst/__init__.py b/frontend/catalyst/__init__.py index eeaf6aa1be..d5645add8c 100644 --- a/frontend/catalyst/__init__.py +++ b/frontend/catalyst/__init__.py @@ -68,6 +68,10 @@ sys.modules["mlir_quantum.ir"] = __import__("jaxlib.mlir.ir").mlir.ir sys.modules["mlir_quantum._mlir_libs"] = __import__("jaxlib.mlir._mlir_libs").mlir._mlir_libs +from catalyst.jax_extras.patches import patch_primitives + +patch_primitives() + from catalyst import debug, logging, passes from catalyst.api_extensions import * from catalyst.api_extensions import __all__ as _api_extension_list diff --git a/frontend/catalyst/jax_extras/lowering.py b/frontend/catalyst/jax_extras/lowering.py index 7dc1382593..0b8e6c30b0 100644 --- a/frontend/catalyst/jax_extras/lowering.py +++ b/frontend/catalyst/jax_extras/lowering.py @@ -23,8 +23,6 @@ from jax._src.effects import ordered_effects as jax_ordered_effects from jax._src.interpreters.mlir import _module_name_regex from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext -from jax._src.source_info_util import new_name_stack -from jax._src.util import wrap_name from jax.extend.core import ClosedJaxpr from jax.interpreters.mlir import ( AxisContext, @@ -72,7 +70,6 @@ def jaxpr_to_mlir(func_name, jaxpr): nrep = jaxpr_replicas(jaxpr) effects = jax_ordered_effects.filter_in(jaxpr.effects) axis_context = ReplicaAxisContext(AxisEnv(nrep, (), ())) - name_stack = new_name_stack(wrap_name("ok", "jit")) module, context = custom_lower_jaxpr_to_module( func_name="jit_" + func_name, module_name=func_name, @@ -80,7 +77,6 @@ def jaxpr_to_mlir(func_name, jaxpr): effects=effects, platform="cpu", axis_context=axis_context, - name_stack=name_stack, ) return module, context @@ -95,7 +91,6 @@ def custom_lower_jaxpr_to_module( effects, platform: str, axis_context: AxisContext, - name_stack, replicated_args=None, arg_shardings=None, result_shardings=None, @@ -142,27 +137,32 @@ def custom_lower_jaxpr_to_module( # XLA computation preserves the module name. module_name = _module_name_regex.sub("_", module_name) ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name) + + # Use main_function=False to preserve the function name (e.g., "jit_func") + # instead of renaming it to "main" lower_jaxpr_to_fun( ctx, func_name, jaxpr, effects, - public=True, + main_function=False, replicated_args=replicated_args, arg_shardings=arg_shardings, result_shardings=result_shardings, - name_stack=name_stack, ) + # Set the entry point function visibility to public and other functions to internal worklist = [*ctx.module.body.operations] while worklist: op = worklist.pop() func_name = str(op.name) - is_entry_point = func_name.startswith('"jit_') - if is_entry_point: - continue if isinstance(op, FuncOp): - op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage") + if func_name.startswith('"jit_'): + # Keep entry point functions public + op.attributes["sym_visibility"] = ir.StringAttr.get("public") + else: + # Set non-entry functions to internal linkage + op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage") if isinstance(op, ModuleOp): worklist += [*op.body.operations] diff --git a/frontend/catalyst/jax_extras/patches.py b/frontend/catalyst/jax_extras/patches.py index ba32f498b9..5c99e83ba0 100644 --- a/frontend/catalyst/jax_extras/patches.py +++ b/frontend/catalyst/jax_extras/patches.py @@ -40,6 +40,7 @@ "_no_clean_up_dead_vars", "_gather_shape_rule_dynamic", "gather2_p", + "patch_primitives", ) @@ -208,3 +209,75 @@ def _gather_shape_rule_dynamic( sharding_rule=_gather_sharding_rule, vma_rule=partial(standard_vma_rule, "gather"), ) + + +# TODO: Remove this patch when JAX/PennyLane are updated to use the new JAX 0.7+ API. +def patch_primitives(): + """Patch PennyLane/JAX primitives to make them compatible with JAX 0.7+. + + JAX 0.7+ requires all primitive parameters to be hashable, but PennyLane + passes **lists** for some parameters like control_values, jaxpr_branches, etc. + This patch wraps the bind method to convert lists to tuples to make them hashable. + """ + + def make_hashable(value): + """Recursively convert lists to tuples to make them hashable.""" + if isinstance(value, list): + return tuple(make_hashable(item) for item in value) + return value + + try: + from pennylane.capture.primitives import ctrl_transform_prim + from pennylane.capture.primitives import cond_prim + from pennylane.ops.op_math.controlled import Controlled + + original_ctrl_bind = ctrl_transform_prim.bind + original_cond_bind = cond_prim.bind + original_controlled_bind = Controlled._primitive.bind + + def patched_ctrl_bind(*args, **kwargs): + # Convert control_values from list to tuple if present + if "control_values" in kwargs: + kwargs["control_values"] = make_hashable(kwargs["control_values"]) + return original_ctrl_bind(*args, **kwargs) + + def patched_cond_bind(*args, **kwargs): + # Convert list parameters to tuples + if "jaxpr_branches" in kwargs: + kwargs["jaxpr_branches"] = make_hashable(kwargs["jaxpr_branches"]) + if "consts_slices" in kwargs: + kwargs["consts_slices"] = make_hashable(kwargs["consts_slices"]) + return original_cond_bind(*args, **kwargs) + + def patched_controlled_bind(*args, **kwargs): + # Convert control_values from list to tuple if present + if "control_values" in kwargs: + kwargs["control_values"] = make_hashable(kwargs["control_values"]) + return original_controlled_bind(*args, **kwargs) + + # Replace the bind method + ctrl_transform_prim.bind = patched_ctrl_bind + cond_prim.bind = patched_cond_bind + Controlled._primitive.bind = patched_controlled_bind + + except ImportError: + pass + + # patch DynamicJaxprTrace members: makevar and getvar + try: + from jax._src.interpreters import partial_eval as pe + + def patched_makevar(self, tracer): + assert tracer.val is None, "a jaxpr variable must be created only once per tracer" + tracer.val = self.frame.newvar(tracer.aval) + return tracer.val + + def patched_getvar(self, tracer): # pylint: disable=unused-argument + if var := tracer.val: + return var + raise jax.core.escaped_tracer_error(tracer) + + pe.DynamicJaxprTrace.makevar = patched_makevar + pe.DynamicJaxprTrace.getvar = patched_getvar + except ImportError: + pass diff --git a/frontend/catalyst/jax_extras/tracing.py b/frontend/catalyst/jax_extras/tracing.py index 12c25e1304..beb66e102a 100644 --- a/frontend/catalyst/jax_extras/tracing.py +++ b/frontend/catalyst/jax_extras/tracing.py @@ -427,9 +427,8 @@ def trace_to_jaxpr( def new_inner_tracer(trace: DynamicJaxprTrace, aval) -> DynamicJaxprTracer: """Create a JAX tracer tracing an abstract value ``aval`, without specifying its source primitive.""" - dt = DynamicJaxprTracer(trace, aval, current_source_info()) - trace.frame.tracers.append(dt) - trace.frame.tracer_to_var[id(dt)] = trace.frame.newvar(aval) + atom = trace.frame.newvar(aval) + dt = DynamicJaxprTracer(trace, aval, atom, line_info=current_source_info()) return dt @@ -907,12 +906,11 @@ def bind(self, *args, **params): # `abstract_eval` returned `out_type` calculated for empty constants. [], tracers, - maker=lambda a: DynamicJaxprTracer(trace, a, source_info), + maker=lambda aval: new_inner_tracer(trace, aval), ) - invars = map(trace.getvar, tracers) - outvars = map(trace.makevar, out_tracers) - + invars = map(lambda t: t.val, tracers) + outvars = map(lambda t: t.val, out_tracers) eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info) trace.frame.add_eqn(eqn) return out_tracers if self.multiple_results else out_tracers.pop() diff --git a/frontend/catalyst/jax_primitives.py b/frontend/catalyst/jax_primitives.py index fec8e41684..ef2404d99c 100644 --- a/frontend/catalyst/jax_primitives.py +++ b/frontend/catalyst/jax_primitives.py @@ -686,7 +686,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params): flat_output_types, ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(args_and_consts), + mlir.flatten_ir_values(args_and_consts), diffArgIndices=diffArgIndices, finiteDiffParam=finiteDiffParam, ).results @@ -763,7 +763,7 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params): gradient_result_types, ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(func_args), + mlir.flatten_ir_values(func_args), diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results @@ -816,8 +816,8 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params): flat_output_types[len(flat_output_types) // 2 :], ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(func_args), - mlir.flatten_lowering_ir_args(tang_args), + mlir.flatten_ir_values(func_args), + mlir.flatten_ir_values(tang_args), diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results @@ -866,8 +866,8 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params): vjp_result_types, ir.StringAttr.get(method), symbol_ref, - mlir.flatten_lowering_ir_args(func_args), - mlir.flatten_lowering_ir_args(cotang_args), + mlir.flatten_ir_values(func_args), + mlir.flatten_ir_values(cotang_args), diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums), finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None, ).results @@ -930,7 +930,7 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn): return ZneOp( flat_output_types, symbol_ref, - mlir.flatten_lowering_ir_args(args_and_consts), + mlir.flatten_ir_values(args_and_consts), _folding_attribute(ctx, folding), num_folds, ).results @@ -1686,17 +1686,18 @@ def custom_measurement_staging_rule( else: out_shapes = tuple(core.DShapedArray(shape, dtype) for dtype in dtypes) - invars = [jaxpr_trace.getvar(obs)] - for dyn_dim in dynamic_shape: - invars.append(jaxpr_trace.getvar(dyn_dim)) + invars = [obs.val] + [t.val for t in dynamic_shape] params = {"static_shape": static_shape} - out_tracers = tuple(pe.DynamicJaxprTracer(jaxpr_trace, out_shape) for out_shape in out_shapes) + out_tracers = tuple( + pe.DynamicJaxprTracer(jaxpr_trace, out_shape, jaxpr_trace.frame.newvar(out_shape)) + for out_shape in out_shapes + ) eqn = pe.new_jaxpr_eqn( invars, - [jaxpr_trace.makevar(out_tracer) for out_tracer in out_tracers], + [out_tracer.val for out_tracer in out_tracers], primitive, params, jax.core.no_effects, @@ -2006,7 +2007,7 @@ def _cond_lowering( num_preds = len(branch_jaxprs) - 1 preds = preds_and_branch_args_plus_consts[:num_preds] branch_args_plus_consts = preds_and_branch_args_plus_consts[num_preds:] - flat_args_plus_consts = mlir.flatten_lowering_ir_args(branch_args_plus_consts) + flat_args_plus_consts = mlir.flatten_ir_values(branch_args_plus_consts) # recursively lower if-else chains to nested IfOps def emit_branches(preds, branch_jaxprs, ip): @@ -2106,7 +2107,7 @@ def _while_loop_lowering( preserve_dimensions: bool, ): loop_carry_types_plus_consts = [mlir.aval_to_ir_types(a)[0] for a in jax_ctx.avals_in] - flat_args_plus_consts = mlir.flatten_lowering_ir_args(iter_args_plus_consts) + flat_args_plus_consts = mlir.flatten_ir_values(iter_args_plus_consts) assert [val.type for val in flat_args_plus_consts] == loop_carry_types_plus_consts # split the argument list into 3 separate groups diff --git a/frontend/catalyst/jax_primitives_utils.py b/frontend/catalyst/jax_primitives_utils.py index 1227b7f1b5..31f02fe30c 100644 --- a/frontend/catalyst/jax_primitives_utils.py +++ b/frontend/catalyst/jax_primitives_utils.py @@ -143,11 +143,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False): kwargs["name"] = name kwargs["jaxpr"] = call_jaxpr kwargs["effects"] = [] - kwargs["name_stack"] = ctx.name_stack - # Make the visibility of the function public=True + # Make the visibility of the function main_function=True # to avoid elimination by the compiler - kwargs["public"] = public + kwargs["main_function"] = public func_op = mlir.lower_jaxpr_to_fun(**kwargs) @@ -269,7 +268,7 @@ def create_call_op(ctx, func_op, *args): """Create a func::CallOp from JAXPR.""" output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out)) flat_output_types = util.flatten(output_types) - mlir_args = mlir.flatten_lowering_ir_args(args) + mlir_args = mlir.flatten_ir_values(args) symbol_ref = get_symbolref(ctx, func_op) is_call_same_module = ctx.module_context.module.operation == func_op.parent constructor = CallOp if is_call_same_module else LaunchKernelOp diff --git a/frontend/catalyst/jax_tracer.py b/frontend/catalyst/jax_tracer.py index 82325b7808..a61a6c0706 100644 --- a/frontend/catalyst/jax_tracer.py +++ b/frontend/catalyst/jax_tracer.py @@ -554,7 +554,7 @@ def bind_overwrite_classical_tracers( for i, t in enumerate(out_expanded_tracers): # We look for what were the previous output tracers. # If they haven't changed, then we leave them unchanged. - if trace.getvar(t) in jaxpr_variables: + if t.val in jaxpr_variables: continue # If the variable cannot be found in the current frame @@ -596,7 +596,7 @@ def bind_overwrite_classical_tracers( # qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs) # # So it should be safe to cache the tracers as we are doing it. - eqn.outvars[i] = trace.getvar(t) + eqn.outvars[i] = t.val # Now, the output variables can be considered as part of the current frame. # This allows us to avoid importing all equations again next time. diff --git a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py index 4d57bdf15b..afa58b851f 100644 --- a/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py +++ b/frontend/test/pytest/from_plxpr/test_from_plxpr_qubit_handler.py @@ -274,8 +274,8 @@ def test_simple_gate(self): # Also check with actual jaxpr variables with take_current_trace() as trace: gate_out_qubits = trace.frame.eqns[-1].outvars - assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0] - assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1] + assert qubit_handler[0].val == gate_out_qubits[0] + assert qubit_handler[1].val == gate_out_qubits[1] def test_iter(self): """Test __iter__ in the qreg manager""" @@ -316,8 +316,8 @@ def test_chained_gate(self): # Also check with actual jaxpr variables with take_current_trace() as trace: gate_out_qubits = trace.frame.eqns[-1].outvars - assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0] - assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1] + assert qubit_handler[0].val == gate_out_qubits[0] + assert qubit_handler[1].val == gate_out_qubits[1] def test_insert_all_dangling_qubits(self): """ diff --git a/frontend/test/pytest/test_debug.py b/frontend/test/pytest/test_debug.py index 93376c4e05..5f3bee511f 100644 --- a/frontend/test/pytest/test_debug.py +++ b/frontend/test/pytest/test_debug.py @@ -587,7 +587,7 @@ def test_no_options_to_mlir_opt(self): module { func.func @foo() { %c = stablehlo.constant dense<0> : tensor - return + return } } """