Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .dep-versions
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.
# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.2
jax=cd1b9520bd2e1b0dd451281f06e98f030347aa8b
stablehlo=0a4440a5c8de45c4f9649bf3eb4913bf3f97da0d
llvm=113f01aa82d055410f22a9d03b3468fa68600589
enzyme=v0.0.203
Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
sys.modules["mlir_quantum.ir"] = __import__("jaxlib.mlir.ir").mlir.ir
sys.modules["mlir_quantum._mlir_libs"] = __import__("jaxlib.mlir._mlir_libs").mlir._mlir_libs

from catalyst.jax_extras.patches import patch_primitives

patch_primitives()

from catalyst import debug, logging, passes
from catalyst.api_extensions import *
from catalyst.api_extensions import __all__ as _api_extension_list
Expand Down
22 changes: 11 additions & 11 deletions frontend/catalyst/jax_extras/lowering.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@
from jax._src.effects import ordered_effects as jax_ordered_effects
from jax._src.interpreters.mlir import _module_name_regex
from jax._src.sharding_impls import AxisEnv, ReplicaAxisContext
from jax._src.source_info_util import new_name_stack
from jax._src.util import wrap_name
from jax.extend.core import ClosedJaxpr
from jax.interpreters.mlir import (
AxisContext,
Expand Down Expand Up @@ -72,15 +70,13 @@ def jaxpr_to_mlir(func_name, jaxpr):
nrep = jaxpr_replicas(jaxpr)
effects = jax_ordered_effects.filter_in(jaxpr.effects)
axis_context = ReplicaAxisContext(AxisEnv(nrep, (), ()))
name_stack = new_name_stack(wrap_name("ok", "jit"))
module, context = custom_lower_jaxpr_to_module(
func_name="jit_" + func_name,
module_name=func_name,
jaxpr=jaxpr,
effects=effects,
platform="cpu",
axis_context=axis_context,
name_stack=name_stack,
)

return module, context
Expand All @@ -95,7 +91,6 @@ def custom_lower_jaxpr_to_module(
effects,
platform: str,
axis_context: AxisContext,
name_stack,
replicated_args=None,
arg_shardings=None,
result_shardings=None,
Expand Down Expand Up @@ -142,27 +137,32 @@ def custom_lower_jaxpr_to_module(
# XLA computation preserves the module name.
module_name = _module_name_regex.sub("_", module_name)
ctx.module.operation.attributes["sym_name"] = ir.StringAttr.get(module_name)

# Use main_function=False to preserve the function name (e.g., "jit_func")
# instead of renaming it to "main"
lower_jaxpr_to_fun(
ctx,
func_name,
jaxpr,
effects,
public=True,
main_function=False,
replicated_args=replicated_args,
arg_shardings=arg_shardings,
result_shardings=result_shardings,
name_stack=name_stack,
)

# Set the entry point function visibility to public and other functions to internal
worklist = [*ctx.module.body.operations]
while worklist:
op = worklist.pop()
func_name = str(op.name)
is_entry_point = func_name.startswith('"jit_')
if is_entry_point:
continue
if isinstance(op, FuncOp):
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
if func_name.startswith('"jit_'):
# Keep entry point functions public
op.attributes["sym_visibility"] = ir.StringAttr.get("public")
else:
# Set non-entry functions to internal linkage
op.attributes["llvm.linkage"] = ir.Attribute.parse("#llvm.linkage<internal>")
if isinstance(op, ModuleOp):
worklist += [*op.body.operations]

Expand Down
73 changes: 73 additions & 0 deletions frontend/catalyst/jax_extras/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
"_no_clean_up_dead_vars",
"_gather_shape_rule_dynamic",
"gather2_p",
"patch_primitives",
)


Expand Down Expand Up @@ -208,3 +209,75 @@ def _gather_shape_rule_dynamic(
sharding_rule=_gather_sharding_rule,
vma_rule=partial(standard_vma_rule, "gather"),
)


# TODO: Remove this patch when JAX/PennyLane are updated to use the new JAX 0.7+ API.
def patch_primitives():
"""Patch PennyLane/JAX primitives to make them compatible with JAX 0.7+.

JAX 0.7+ requires all primitive parameters to be hashable, but PennyLane
passes **lists** for some parameters like control_values, jaxpr_branches, etc.
This patch wraps the bind method to convert lists to tuples to make them hashable.
"""

def make_hashable(value):
"""Recursively convert lists to tuples to make them hashable."""
if isinstance(value, list):
return tuple(make_hashable(item) for item in value)
return value

try:
from pennylane.capture.primitives import ctrl_transform_prim
from pennylane.capture.primitives import cond_prim
from pennylane.ops.op_math.controlled import Controlled

original_ctrl_bind = ctrl_transform_prim.bind
original_cond_bind = cond_prim.bind
original_controlled_bind = Controlled._primitive.bind

def patched_ctrl_bind(*args, **kwargs):
# Convert control_values from list to tuple if present
if "control_values" in kwargs:
kwargs["control_values"] = make_hashable(kwargs["control_values"])
return original_ctrl_bind(*args, **kwargs)

def patched_cond_bind(*args, **kwargs):
# Convert list parameters to tuples
if "jaxpr_branches" in kwargs:
kwargs["jaxpr_branches"] = make_hashable(kwargs["jaxpr_branches"])
if "consts_slices" in kwargs:
kwargs["consts_slices"] = make_hashable(kwargs["consts_slices"])
return original_cond_bind(*args, **kwargs)

def patched_controlled_bind(*args, **kwargs):
# Convert control_values from list to tuple if present
if "control_values" in kwargs:
kwargs["control_values"] = make_hashable(kwargs["control_values"])
return original_controlled_bind(*args, **kwargs)

# Replace the bind method
ctrl_transform_prim.bind = patched_ctrl_bind
cond_prim.bind = patched_cond_bind
Controlled._primitive.bind = patched_controlled_bind

except ImportError:
pass

# patch DynamicJaxprTrace members: makevar and getvar
try:
from jax._src.interpreters import partial_eval as pe

def patched_makevar(self, tracer):
assert tracer.val is None, "a jaxpr variable must be created only once per tracer"
tracer.val = self.frame.newvar(tracer.aval)
return tracer.val

def patched_getvar(self, tracer): # pylint: disable=unused-argument
if var := tracer.val:
return var
raise jax.core.escaped_tracer_error(tracer)

pe.DynamicJaxprTrace.makevar = patched_makevar
pe.DynamicJaxprTrace.getvar = patched_getvar
except ImportError:
pass
12 changes: 5 additions & 7 deletions frontend/catalyst/jax_extras/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,9 +427,8 @@ def trace_to_jaxpr(
def new_inner_tracer(trace: DynamicJaxprTrace, aval) -> DynamicJaxprTracer:
"""Create a JAX tracer tracing an abstract value ``aval`, without specifying its source
primitive."""
dt = DynamicJaxprTracer(trace, aval, current_source_info())
trace.frame.tracers.append(dt)
trace.frame.tracer_to_var[id(dt)] = trace.frame.newvar(aval)
atom = trace.frame.newvar(aval)
dt = DynamicJaxprTracer(trace, aval, atom, line_info=current_source_info())
return dt


Expand Down Expand Up @@ -907,12 +906,11 @@ def bind(self, *args, **params):
# `abstract_eval` returned `out_type` calculated for empty constants.
[],
tracers,
maker=lambda a: DynamicJaxprTracer(trace, a, source_info),
maker=lambda aval: new_inner_tracer(trace, aval),
)

invars = map(trace.getvar, tracers)
outvars = map(trace.makevar, out_tracers)

invars = map(lambda t: t.val, tracers)
outvars = map(lambda t: t.val, out_tracers)
eqn = new_jaxpr_eqn(invars, outvars, self, params, [], source_info)
trace.frame.add_eqn(eqn)
return out_tracers if self.multiple_results else out_tracers.pop()
Expand Down
29 changes: 15 additions & 14 deletions frontend/catalyst/jax_primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ def _grad_lowering(ctx, *args, jaxpr, fn, grad_params):
flat_output_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(args_and_consts),
mlir.flatten_ir_values(args_and_consts),
diffArgIndices=diffArgIndices,
finiteDiffParam=finiteDiffParam,
).results
Expand Down Expand Up @@ -763,7 +763,7 @@ def _value_and_grad_lowering(ctx, *args, jaxpr, fn, grad_params):
gradient_result_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(func_args),
mlir.flatten_ir_values(func_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
Expand Down Expand Up @@ -816,8 +816,8 @@ def _jvp_lowering(ctx, *args, jaxpr, fn, grad_params):
flat_output_types[len(flat_output_types) // 2 :],
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(func_args),
mlir.flatten_lowering_ir_args(tang_args),
mlir.flatten_ir_values(func_args),
mlir.flatten_ir_values(tang_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
Expand Down Expand Up @@ -866,8 +866,8 @@ def _vjp_lowering(ctx, *args, jaxpr, fn, grad_params):
vjp_result_types,
ir.StringAttr.get(method),
symbol_ref,
mlir.flatten_lowering_ir_args(func_args),
mlir.flatten_lowering_ir_args(cotang_args),
mlir.flatten_ir_values(func_args),
mlir.flatten_ir_values(cotang_args),
diffArgIndices=ir.DenseIntElementsAttr.get(new_argnums),
finiteDiffParam=ir.FloatAttr.get(ir.F64Type.get(mlir_ctx), h) if h else None,
).results
Expand Down Expand Up @@ -930,7 +930,7 @@ def _zne_lowering(ctx, *args, folding, jaxpr, fn):
return ZneOp(
flat_output_types,
symbol_ref,
mlir.flatten_lowering_ir_args(args_and_consts),
mlir.flatten_ir_values(args_and_consts),
_folding_attribute(ctx, folding),
num_folds,
).results
Expand Down Expand Up @@ -1686,17 +1686,18 @@ def custom_measurement_staging_rule(
else:
out_shapes = tuple(core.DShapedArray(shape, dtype) for dtype in dtypes)

invars = [jaxpr_trace.getvar(obs)]
for dyn_dim in dynamic_shape:
invars.append(jaxpr_trace.getvar(dyn_dim))
invars = [obs.val] + [t.val for t in dynamic_shape]

params = {"static_shape": static_shape}

out_tracers = tuple(pe.DynamicJaxprTracer(jaxpr_trace, out_shape) for out_shape in out_shapes)
out_tracers = tuple(
pe.DynamicJaxprTracer(jaxpr_trace, out_shape, jaxpr_trace.frame.newvar(out_shape))
for out_shape in out_shapes
)

eqn = pe.new_jaxpr_eqn(
invars,
[jaxpr_trace.makevar(out_tracer) for out_tracer in out_tracers],
[out_tracer.val for out_tracer in out_tracers],
primitive,
params,
jax.core.no_effects,
Expand Down Expand Up @@ -2006,7 +2007,7 @@ def _cond_lowering(
num_preds = len(branch_jaxprs) - 1
preds = preds_and_branch_args_plus_consts[:num_preds]
branch_args_plus_consts = preds_and_branch_args_plus_consts[num_preds:]
flat_args_plus_consts = mlir.flatten_lowering_ir_args(branch_args_plus_consts)
flat_args_plus_consts = mlir.flatten_ir_values(branch_args_plus_consts)

# recursively lower if-else chains to nested IfOps
def emit_branches(preds, branch_jaxprs, ip):
Expand Down Expand Up @@ -2106,7 +2107,7 @@ def _while_loop_lowering(
preserve_dimensions: bool,
):
loop_carry_types_plus_consts = [mlir.aval_to_ir_types(a)[0] for a in jax_ctx.avals_in]
flat_args_plus_consts = mlir.flatten_lowering_ir_args(iter_args_plus_consts)
flat_args_plus_consts = mlir.flatten_ir_values(iter_args_plus_consts)
assert [val.type for val in flat_args_plus_consts] == loop_carry_types_plus_consts

# split the argument list into 3 separate groups
Expand Down
7 changes: 3 additions & 4 deletions frontend/catalyst/jax_primitives_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,11 +143,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
kwargs["name"] = name
kwargs["jaxpr"] = call_jaxpr
kwargs["effects"] = []
kwargs["name_stack"] = ctx.name_stack

# Make the visibility of the function public=True
# Make the visibility of the function main_function=True
# to avoid elimination by the compiler
kwargs["public"] = public
kwargs["main_function"] = public

func_op = mlir.lower_jaxpr_to_fun(**kwargs)

Expand Down Expand Up @@ -269,7 +268,7 @@ def create_call_op(ctx, func_op, *args):
"""Create a func::CallOp from JAXPR."""
output_types = list(map(mlir.aval_to_ir_types, ctx.avals_out))
flat_output_types = util.flatten(output_types)
mlir_args = mlir.flatten_lowering_ir_args(args)
mlir_args = mlir.flatten_ir_values(args)
symbol_ref = get_symbolref(ctx, func_op)
is_call_same_module = ctx.module_context.module.operation == func_op.parent
constructor = CallOp if is_call_same_module else LaunchKernelOp
Expand Down
4 changes: 2 additions & 2 deletions frontend/catalyst/jax_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -554,7 +554,7 @@ def bind_overwrite_classical_tracers(
for i, t in enumerate(out_expanded_tracers):
# We look for what were the previous output tracers.
# If they haven't changed, then we leave them unchanged.
if trace.getvar(t) in jaxpr_variables:
if t.val in jaxpr_variables:
continue

# If the variable cannot be found in the current frame
Expand Down Expand Up @@ -596,7 +596,7 @@ def bind_overwrite_classical_tracers(
# qrp2 = op.trace_quantum(ctx, device, trace, qrp, **kwargs)
#
# So it should be safe to cache the tracers as we are doing it.
eqn.outvars[i] = trace.getvar(t)
eqn.outvars[i] = t.val

# Now, the output variables can be considered as part of the current frame.
# This allows us to avoid importing all equations again next time.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -274,8 +274,8 @@ def test_simple_gate(self):
# Also check with actual jaxpr variables
with take_current_trace() as trace:
gate_out_qubits = trace.frame.eqns[-1].outvars
assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0]
assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1]
assert qubit_handler[0].val == gate_out_qubits[0]
assert qubit_handler[1].val == gate_out_qubits[1]

def test_iter(self):
"""Test __iter__ in the qreg manager"""
Expand Down Expand Up @@ -316,8 +316,8 @@ def test_chained_gate(self):
# Also check with actual jaxpr variables
with take_current_trace() as trace:
gate_out_qubits = trace.frame.eqns[-1].outvars
assert trace.frame.tracer_to_var[id(qubit_handler[0])] == gate_out_qubits[0]
assert trace.frame.tracer_to_var[id(qubit_handler[1])] == gate_out_qubits[1]
assert qubit_handler[0].val == gate_out_qubits[0]
assert qubit_handler[1].val == gate_out_qubits[1]

def test_insert_all_dangling_qubits(self):
"""
Expand Down
2 changes: 1 addition & 1 deletion frontend/test/pytest/test_debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ def test_no_options_to_mlir_opt(self):
module {
func.func @foo() {
%c = stablehlo.constant dense<0> : tensor<i64>
return
return
}
}
"""
Expand Down
Loading