2626import jax
2727import pennylane as qml
2828from jax ._src .tree_util import tree_flatten
29+ from jax .api_util import debug_info
2930from jax .core import get_aval
3031from pennylane import QueuingManager
3132from pennylane .operation import Operator
@@ -146,7 +147,7 @@ def circuit():
146147 EvaluationContext .check_is_quantum_tracing (
147148 "catalyst.measure can only be used from within a qml.qnode."
148149 )
149- ctx = EvaluationContext .get_main_tracing_context ()
150+ cur_trace = EvaluationContext .get_current_trace ()
150151 wires = list (wires ) if isinstance (wires , (list , tuple )) else [wires ]
151152 if len (wires ) != 1 :
152153 raise TypeError (f"Only one element is supported for the 'wires' parameter, got { wires } ." )
@@ -161,7 +162,7 @@ def circuit():
161162 if postselect is not None and postselect not in [0 , 1 ]:
162163 raise TypeError (f"postselect must be '0' or '1', got { postselect } " )
163164
164- m = new_inner_tracer (ctx . trace , get_aval (True ))
165+ m = new_inner_tracer (cur_trace , get_aval (True ))
165166 MidCircuitMeasure (
166167 in_classical_tracers = in_classical_tracers ,
167168 out_classical_tracers = [m ],
@@ -344,14 +345,12 @@ def trace_quantum(self, ctx, device, trace, qrp, postselect_mode=None) -> QRegPr
344345 qubit = qrp .extract (self .wires )[0 ]
345346 if postselect_mode == "hw-like" :
346347 qubit2 = self .bind_overwrite_classical_tracers (
347- ctx ,
348348 trace ,
349349 in_expanded_tracers = [qubit ],
350350 out_expanded_tracers = self .out_classical_tracers ,
351351 )
352352 else :
353353 qubit2 = self .bind_overwrite_classical_tracers (
354- ctx ,
355354 trace ,
356355 in_expanded_tracers = [qubit ],
357356 out_expanded_tracers = self .out_classical_tracers ,
@@ -405,12 +404,14 @@ def __call__(self, *args, **kwargs):
405404 return create_adjoint_op (base_op , self .lazy )
406405
407406 tracing_artifacts = self .trace_body (args , kwargs )
408-
409- return HybridAdjoint (* tracing_artifacts )
407+ dbg = tracing_artifacts [ 3 ]
408+ return HybridAdjoint (* tracing_artifacts [ 0 : 3 ], debug_info = dbg )
410409
411410 def trace_body (self , args , kwargs ):
412411 """Generate a HybridOpRegion for use by Catalyst."""
413412
413+ dbg = debug_info ("AdjointCallable" , self .target , args , kwargs )
414+
414415 # Allow the creation of HybridAdjoint instances outside of any contexts.
415416 # Don't create a JAX context here as otherwise we could be dealing with escaped tracers.
416417 if not EvaluationContext .is_tracing ():
@@ -420,18 +421,17 @@ def trace_body(self, args, kwargs):
420421
421422 adjoint_region = HybridOpRegion (None , quantum_tape , [], [])
422423
423- return [], [], [adjoint_region ]
424+ return [], [], [adjoint_region ], dbg
424425
425426 # Create a nested jaxpr scope for the body of the adjoint.
426- ctx = EvaluationContext .get_main_tracing_context ()
427- with EvaluationContext .frame_tracing_context (ctx ) as inner_trace :
427+ with EvaluationContext .frame_tracing_context (debug_info = dbg ) as inner_trace :
428428 in_classical_tracers , _ = tree_flatten ((args , kwargs ))
429- wffa , in_avals , _ , _ = deduce_avals (self .target , args , kwargs )
429+ wffa , in_avals , _ , _ = deduce_avals (self .target , args , kwargs , debug_info = dbg )
430430 arg_classical_tracers = _input_type_to_tracers (inner_trace .new_arg , in_avals )
431431 with QueuingManager .stop_recording (), QuantumTape () as quantum_tape :
432- # FIXME: move all full_raise calls into a separate function
432+ # FIXME: move all to_jaxpr_tracer calls into a separate function
433433 res_classical_tracers = [
434- inner_trace .full_raise (t )
434+ inner_trace .to_jaxpr_tracer (t )
435435 for t in wffa .call_wrapped (* arg_classical_tracers )
436436 if isinstance (t , DynamicJaxprTracer )
437437 ]
@@ -442,7 +442,7 @@ def trace_body(self, args, kwargs):
442442 inner_trace , quantum_tape , arg_classical_tracers , res_classical_tracers
443443 )
444444
445- return in_classical_tracers , [], [adjoint_region ]
445+ return in_classical_tracers , [], [adjoint_region ], wffa . debug_info
446446
447447
448448class HybridAdjoint (HybridOp ):
@@ -457,19 +457,20 @@ def trace_quantum(self, ctx, device, _trace, qrp) -> QRegPromise:
457457 body_trace = op .regions [0 ].trace
458458 body_tape = op .regions [0 ].quantum_tape
459459 res_classical_tracers = op .regions [0 ].res_classical_tracers
460+ dbg = op .debug_info
460461
461462 # Handle ops that were instantiated outside of a tracing context.
462463 if body_trace is None :
463- frame_ctx = EvaluationContext .frame_tracing_context (ctx )
464+ frame_ctx = EvaluationContext .frame_tracing_context (debug_info = dbg )
464465 else :
465- frame_ctx = EvaluationContext .frame_tracing_context (ctx , body_trace )
466+ frame_ctx = EvaluationContext .frame_tracing_context (body_trace )
466467
467468 with frame_ctx as body_trace :
468469 qreg_in = _input_type_to_tracers (body_trace .new_arg , [AbstractQreg ()])[0 ]
469470 qrp_out = trace_quantum_operations (body_tape , device , qreg_in , ctx , body_trace )
470471 qreg_out = qrp_out .actualize ()
471- body_jaxpr , _ , body_consts = ctx . frames [ body_trace ] .to_jaxpr2 (
472- res_classical_tracers + [qreg_out ]
472+ body_jaxpr , _ , body_consts = body_trace . frame .to_jaxpr2 (
473+ res_classical_tracers + [qreg_out ], dbg
473474 )
474475
475476 qreg = qrp .actualize ()
@@ -657,9 +658,9 @@ def ctrl_distribute(
657658
658659 # Allow decompositions outside of a Catalyst context.
659660 if EvaluationContext .is_tracing ():
660- ctx = EvaluationContext .get_main_tracing_context ()
661+ cur_trace = EvaluationContext .get_current_trace ()
661662 else :
662- ctx = None
663+ cur_trace = None
663664
664665 new_ops = []
665666 for op in tape .operations :
@@ -675,8 +676,8 @@ def ctrl_distribute(
675676 else :
676677 for region in [region for region in op .regions if region .quantum_tape is not None ]:
677678 # Re-enter a JAXPR frame but do not create a new one is none exists.
678- if ctx and region .trace :
679- trace_manager = EvaluationContext .frame_tracing_context (ctx , region .trace )
679+ if cur_trace and region .trace :
680+ trace_manager = EvaluationContext .frame_tracing_context (region .trace )
680681 else :
681682 trace_manager = nullcontext
682683
0 commit comments