Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
jax=cd1b9520bd2e1b0dd451281f06e98f030347aa8b
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203
Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
sys.modules["mlir_quantum.ir"] = __import__("jaxlib.mlir.ir").mlir.ir
sys.modules["mlir_quantum._mlir_libs"] = __import__("jaxlib.mlir._mlir_libs").mlir._mlir_libs

from catalyst.jax_extras.patches import patch_primitives

patch_primitives()

from catalyst import debug, logging, passes
from catalyst.api_extensions import *
from catalyst.api_extensions import __all__ as _api_extension_list
Expand Down
22 changes: 11 additions & 11 deletions frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex
from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext
from jax._src.source_info_util import new_name_stack
from jax._src.util import wrap_name
from jax.extend.core import ClosedJaxpr
from jax.interpreters.mlir import (
AxisContext,
Expand Down Expand Up @@ -72,15 +70,13 @@ def jaxpr_to_mlir(func_name, jaxpr):
nrep = jaxpr_replicas(jaxpr)
effects = jax_ordered_effects.filter_in(jaxpr.effects)
axis_context = ReplicaAxisContext(AxisEnv(nrep, (), ()))
name_stack = new_name_stack(wrap_name("ok", "jit"))
module, context = custom_lower_jaxpr_to_module(
func_name="jit_" + func_name,
module_name=func_name,
jaxpr=jaxpr,
effects=effects,
platform="cpu",
axis_context=axis_context,
name_stack=name_stack,
)

return module, context
Expand All @@ -95,7 +91,6 @@ def custom_lower_jaxpr_to_module(
effects,
platform: str,
axis_context: AxisContext,
name_stack,
replicated_args=None,
arg_shardings=None,
result_shardings=None,
Expand Down Expand Up @@ -142,27 +137,32 @@ def custom_lower_jaxpr_to_module(
# XLA computation preserves the module name.
module_name = _module_name_regex.sub("_", module_name)
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)

# Use main_function=False to preserve the function name (e.g., "jit_func")
# instead of renaming it to "main"
lower_jaxpr_to_fun(
ctx,
func_name,
jaxpr,
effects,
public=True,
main_function=False,
replicated_args=replicated_args,
arg_shardings=arg_shardings,
result_shardings=result_shardings,
name_stack=name_stack,
)

# Set the entry point function visibility to public and other functions to internal
worklist = [*ctx.module.body.operations]
while worklist:
op = worklist.pop()
func_name = str(op.name)
is_entry_point = func_name.startswith('"jit_')
if is_entry_point:
continue
if isinstance(op, FuncOp):
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
if func_name.startswith('"jit_'):
# Keep entry point functions public
op.attributes["sym_visibility"] = ir.StringAttr.get("public")
else:
# Set non-entry functions to internal linkage
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
if isinstance(op, ModuleOp):
worklist += [*op.body.operations]

Expand Down
73 changes: 73 additions & 0 deletions frontend/catalyst/jax_extras/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"_no_clean_up_dead_vars",
"_gather_shape_rule_dynamic",
"gather2_p",
"patch_primitives",
)


Expand Down Expand Up @@ -208,3 +209,75 @@ def _gather_shape_rule_dynamic(
sharding_rule=_gather_sharding_rule,
vma_rule=partial(standard_vma_rule, "gather"),
)


# TODO: Remove this patch when JAX/PennyLane are updated to use the new JAX 0.7+ API.
def patch_primitives():
"""Patch PennyLane/JAX primitives to make them compatible with JAX 0.7+.

JAX 0.7+ requires all primitive parameters to be hashable, but PennyLane
passes **lists** for some parameters like control_values, jaxpr_branches, etc.
This patch wraps the bind method to convert lists to tuples to make them hashable.
"""

def make_hashable(value):
"""Recursively convert lists to tuples to make them hashable."""
if isinstance(value, list):
return tuple(make_hashable(item) for item in value)
return value

try:
from pennylane.capture.primitives import ctrl_transform_prim
from pennylane.capture.primitives import cond_prim
from pennylane.ops.op_math.controlled import Controlled

original_ctrl_bind = ctrl_transform_prim.bind
original_cond_bind = cond_prim.bind
original_controlled_bind = Controlled._primitive.bind

def patched_ctrl_bind(*args, **kwargs):
# Convert control_values from list to tuple if present
if "control_values" in kwargs:
kwargs["control_values"] = make_hashable(kwargs["control_values"])
return original_ctrl_bind(*args, **kwargs)

def patched_cond_bind(*args, **kwargs):
# Convert list parameters to tuples
if "jaxpr_branches" in kwargs:
kwargs["jaxpr_branches"] = make_hashable(kwargs["jaxpr_branches"])
if "consts_slices" in kwargs:
kwargs["consts_slices"] = make_hashable(kwargs["consts_slices"])
return original_cond_bind(*args, **kwargs)

def patched_controlled_bind(*args, **kwargs):
# Convert control_values from list to tuple if present
if "control_values" in kwargs:
kwargs["control_values"] = make_hashable(kwargs["control_values"])
return original_controlled_bind(*args, **kwargs)

# Replace the bind method
ctrl_transform_prim.bind = patched_ctrl_bind
cond_prim.bind = patched_cond_bind
Controlled._primitive.bind = patched_controlled_bind

except ImportError:
pass

# patch DynamicJaxprTrace members: makevar and getvar
try:
from jax._src.interpreters import partial_eval as pe

def patched_makevar(self, tracer):
assert tracer.val is None, "a jaxpr variable must be created only once per tracer"
tracer.val = self.frame.newvar(tracer.aval)
return tracer.val

def patched_getvar(self, tracer): # pylint: disable=unused-argument
if var := tracer.val:
return var
raise jax.core.escaped_tracer_error(tracer)

pe.DynamicJaxprTrace.makevar = patched_makevar
pe.DynamicJaxprTrace.getvar = patched_getvar
except ImportError:
pass
12 changes: 5 additions & 7 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,8 @@ def trace_to_jaxpr(
def new_inner_tracer(trace: DynamicJaxprTrace, aval) -> DynamicJaxprTracer:
"""Create a JAX tracer tracing an abstract value ``aval`, without specifying its source
primitive."""
dt = DynamicJaxprTracer(trace, aval, current_source_info())
trace.frame.tracers.append(dt)
trace.frame.tracer_to_var[id(dt)] = trace.frame.newvar(aval)
atom = trace.frame.newvar(aval)
dt = DynamicJaxprTracer(trace, aval, atom, line_info=current_source_info())
return dt


Expand Down Expand Up @@ -907,12 +906,11 @@ def bind(self, *args, **params):
# `abstract_eval` returned `out_type` calculated for empty constants.
[],
tracers,
maker=lambda a: DynamicJaxprTracer(trace, a, source_info),
maker=lambda aval: new_inner_tracer(trace, aval),
)

invars = map(trace.getvar, tracers)
outvars = map(trace.makevar, out_tracers)

invars = map(lambda t: t.val, tracers)
outvars = map(lambda t: t.val, out_tracers)
eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info)
trace.frame.add_eqn(eqn)
return out_tracers if self.multiple_results else out_tracers.pop()
Expand Down
29 changes: 15 additions & 14 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
flat_output_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(args_and_consts),
mlir.flatten_ir_values(args_and_consts),
diffArgIndices=diffArgIndices,
finiteDiffParam=finiteDiffParam,
).results
Expand Down Expand Up @@ -763,7 +763,7 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params):
gradient_result_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(func_args),
mlir.flatten_ir_values(func_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
Expand Down Expand Up @@ -816,8 +816,8 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params):
flat_output_types[len(flat_output_types) // 2 :],
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(func_args),
mlir.flatten_lowering_ir_args(tang_args),
mlir.flatten_ir_values(func_args),
mlir.flatten_ir_values(tang_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
Expand Down Expand Up @@ -866,8 +866,8 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params):
vjp_result_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(func_args),
mlir.flatten_lowering_ir_args(cotang_args),
mlir.flatten_ir_values(func_args),
mlir.flatten_ir_values(cotang_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
Expand Down Expand Up @@ -930,7 +930,7 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
return ZneOp(
flat_output_types,
symbol_ref,
mlir.flatten_lowering_ir_args(args_and_consts),
mlir.flatten_ir_values(args_and_consts),
_folding_attribute(ctx, folding),
num_folds,
).results
Expand Down Expand Up @@ -1686,17 +1686,18 @@ def custom_measurement_staging_rule(
else:
out_shapes = tuple(core.DShapedArray(shape, dtype) for dtype in dtypes)

invars = [jaxpr_trace.getvar(obs)]
for dyn_dim in dynamic_shape:
invars.append(jaxpr_trace.getvar(dyn_dim))
invars = [obs.val] + [t.val for t in dynamic_shape]

params = {"static_shape": static_shape}

out_tracers = tuple(pe.DynamicJaxprTracer(jaxpr_trace, out_shape) for out_shape in out_shapes)
out_tracers = tuple(
pe.DynamicJaxprTracer(jaxpr_trace, out_shape, jaxpr_trace.frame.newvar(out_shape))
for out_shape in out_shapes
)

eqn = pe.new_jaxpr_eqn(
invars,
[jaxpr_trace.makevar(out_tracer) for out_tracer in out_tracers],
[out_tracer.val for out_tracer in out_tracers],
primitive,
params,
jax.core.no_effects,
Expand Down Expand Up @@ -2006,7 +2007,7 @@ def _cond_lowering(
num_preds = len(branch_jaxprs) - 1
preds = preds_and_branch_args_plus_consts[:num_preds]
branch_args_plus_consts = preds_and_branch_args_plus_consts[num_preds:]
flat_args_plus_consts = mlir.flatten_lowering_ir_args(branch_args_plus_consts)
flat_args_plus_consts = mlir.flatten_ir_values(branch_args_plus_consts)

# recursively lower if-else chains to nested IfOps
def emit_branches(preds, branch_jaxprs, ip):
Expand Down Expand Up @@ -2106,7 +2107,7 @@ def _while_loop_lowering(
preserve_dimensions: bool,
):
loop_carry_types_plus_consts = [mlir.aval_to_ir_types(a)[0] for a in jax_ctx.avals_in]
flat_args_plus_consts = mlir.flatten_lowering_ir_args(iter_args_plus_consts)
flat_args_plus_consts = mlir.flatten_ir_values(iter_args_plus_consts)
assert [val.type for val in flat_args_plus_consts] == loop_carry_types_plus_consts

# split the argument list into 3 separate groups
Expand Down
7 changes: 3 additions & 4 deletions frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
kwargs["name"] = name
kwargs["jaxpr"] = call_jaxpr
kwargs["effects"] = []
kwargs["name_stack"] = ctx.name_stack

# Make the visibility of the function public=True
# Make the visibility of the function main_function=True
# to avoid elimination by the compiler
kwargs["public"] = public
kwargs["main_function"] = public

func_op = mlir.lower_jaxpr_to_fun(**kwargs)

Expand Down Expand Up @@ -269,7 +268,7 @@ def create_call_op(ctx, func_op, *args):
"""Create a func::CallOp from JAXPR."""
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
mlir_args = mlir.flatten_lowering_ir_args(args)
mlir_args = mlir.flatten_ir_values(args)
symbol_ref = get_symbolref(ctx, func_op)
is_call_same_module = ctx.module_context.module.operation == func_op.parent
constructor = CallOp if is_call_same_module else LaunchKernelOp
Expand Down
4 changes: 2 additions & 2 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def bind_overwrite_classical_tracers(
for i, t in enumerate(out_expanded_tracers):
# We look for what were the previous output tracers.
# If they haven't changed, then we leave them unchanged.
if trace.getvar(t) in jaxpr_variables:
if t.val in jaxpr_variables:
continue

# If the variable cannot be found in the current frame
Expand Down Expand Up @@ -596,7 +596,7 @@ def bind_overwrite_classical_tracers(
# qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs)
#
# So it should be safe to cache the tracers as we are doing it.
eqn.outvars[i] = trace.getvar(t)
eqn.outvars[i] = t.val

# Now, the output variables can be considered as part of the current frame.
# This allows us to avoid importing all equations again next time.
Expand Down
35 changes: 35 additions & 0 deletions frontend/scan.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#!/bin/bash
# 保存為 check_jax_migration.sh

echo "🔍 JAX 0.7.0 遷移檢查"
echo "===================="

echo -e "\n❌ P0 - 必須修改的問題:"
echo "----------------------"

echo -e "\n1. DynamicJaxprTracer 構造函數:"
rg "DynamicJaxprTracer\(" --type py -C 2

echo -e "\n2. trace.getvar() 調用:"
rg "\.getvar\(" --type py

echo -e "\n3. trace.makevar() 調用:"
rg "\.makevar\(" --type py

echo -e "\n4. tracer_to_var 使用:"
rg "tracer_to_var" --type py

echo -e "\n5. frame.tracers 使用:"
rg "frame\.tracers" --type py

echo -e "\n\n⚠️ P1 - 需要檢查的代碼:"
echo "----------------------"

echo -e "\n6. 自定義 staging rules:"
rg "staging_rule|process_primitive" --type py

echo -e "\n7. JaxprStackFrame 使用:"
rg "JaxprStackFrame" --type py

echo -e "\n8. Tracer ID 使用:"
rg "id\(.*tracer" --type py
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def test_simple_gate(self):
# Also check with actual jaxpr variables
with take_current_trace() as trace:
gate_out_qubits = trace.frame.eqns[-1].outvars
assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0]
assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1]
assert qubit_handler[0].val == gate_out_qubits[0]
assert qubit_handler[1].val == gate_out_qubits[1]

def test_iter(self):
"""Test __iter__ in the qreg manager"""
Expand Down Expand Up @@ -316,8 +316,8 @@ def test_chained_gate(self):
# Also check with actual jaxpr variables
with take_current_trace() as trace:
gate_out_qubits = trace.frame.eqns[-1].outvars
assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0]
assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1]
assert qubit_handler[0].val == gate_out_qubits[0]
assert qubit_handler[1].val == gate_out_qubits[1]

def test_insert_all_dangling_qubits(self):
"""
Expand Down
Loading
Loading