2929from catalyst .jax_extras .lowering import get_mlir_attribute_from_pyval
3030
3131
32+ def _only_single_expval (call_jaxpr : core .ClosedJaxpr ) -> bool :
33+ found_expval = False
34+ for eqn in call_jaxpr .eqns :
35+ name = eqn .primitive .name
36+ if name in {"probs" , "counts" , "sample" }:
37+ return False
38+ elif name == "expval" :
39+ if found_expval :
40+ return False
41+ found_expval = True
42+ return True
43+
44+
45+ def _calculate_diff_method (qn : qml .QNode , call_jaxpr : core .ClosedJaxpr ):
46+ diff_method = str (qn .diff_method )
47+ if diff_method != "best" :
48+ return diff_method
49+
50+ device_name = getattr (getattr (qn , "device" , None ), "name" , None )
51+
52+ if device_name and "lightning" in device_name and _only_single_expval (call_jaxpr ):
53+ return "adjoint"
54+ return "parameter-shift"
55+
56+
3257def get_call_jaxpr (jaxpr ):
3358 """Extracts the `call_jaxpr` from a JAXPR if it exists.""" ""
3459 for eqn in jaxpr .eqns :
@@ -45,28 +70,36 @@ def get_call_equation(jaxpr):
4570 raise AssertionError ("No call_jaxpr found in the JAXPR." )
4671
4772
48- def lower_jaxpr (ctx , jaxpr , context = None ):
73+ def lower_jaxpr (ctx , jaxpr , metadata = None , fn = None ):
4974 """Lowers a call primitive jaxpr, may be either func_p or quantum_kernel_p
5075
5176 Args:
5277 ctx: LoweringRuleContext
5378 jaxpr: JAXPR to be lowered
54- context: additional context to distinguish different FuncOps
79+ metadata: additional metadata to distinguish different FuncOps
80+ fn (Callable | None): the function the jaxpr corresponds to. Used for naming and caching.
5581
5682 Returns:
5783 FuncOp
5884 """
59- equation = get_call_equation (jaxpr )
60- call_jaxpr = equation .params ["call_jaxpr" ]
61- callable_ = equation .params .get ("fn" )
62- if callable_ is None :
63- callable_ = equation .params .get ("qnode" )
64- pipeline = equation .params .get ("pipeline" )
65- return lower_callable (ctx , callable_ , call_jaxpr , pipeline = pipeline , context = context )
85+
86+ if fn is None or isinstance (fn , qml .QNode ):
87+ equation = get_call_equation (jaxpr )
88+ call_jaxpr = equation .params ["call_jaxpr" ]
89+ pipeline = equation .params .get ("pipeline" )
90+ callable_ = equation .params .get ("fn" )
91+ if callable_ is None :
92+ callable_ = equation .params .get ("qnode" , None )
93+ else :
94+ call_jaxpr = jaxpr
95+ pipeline = ()
96+ callable_ = fn
97+
98+ return lower_callable (ctx , callable_ , call_jaxpr , pipeline = pipeline , metadata = metadata )
6699
67100
68101# pylint: disable=too-many-arguments, too-many-positional-arguments
69- def lower_callable (ctx , callable_ , call_jaxpr , pipeline = None , context = None , public = False ):
102+ def lower_callable (ctx , callable_ , call_jaxpr , pipeline = (), metadata = None , public = False ):
70103 """Lowers _callable to MLIR.
71104
72105 If callable_ is a qnode, then we will first create a module, then
@@ -86,33 +119,33 @@ def lower_callable(ctx, callable_, call_jaxpr, pipeline=None, context=None, publ
86119 if pipeline is None :
87120 pipeline = tuple ()
88121
89- if not isinstance (callable_ , qml .QNode ):
90- return get_or_create_funcop (
91- ctx , callable_ , call_jaxpr , pipeline , context = context , public = public
92- )
93-
94- return get_or_create_qnode_funcop (ctx , callable_ , call_jaxpr , pipeline , context = context )
122+ if isinstance (callable_ , qml .QNode ):
123+ return get_or_create_qnode_funcop (ctx , callable_ , call_jaxpr , pipeline , metadata = metadata )
124+ return get_or_create_funcop (
125+ ctx , callable_ , call_jaxpr , pipeline , metadata = metadata , public = public
126+ )
95127
96128
97129# pylint: disable=too-many-arguments, too-many-positional-arguments
98- def get_or_create_funcop (ctx , callable_ , call_jaxpr , pipeline , context = None , public = False ):
130+ def get_or_create_funcop (ctx , callable_ , call_jaxpr , pipeline , metadata = None , public = False ):
99131 """Get funcOp from cache, or create it from scratch
100132
101133 Args:
102134 ctx: LoweringRuleContext
103135 callable_: python function
104136 call_jaxpr: jaxpr representing callable_
105- context : additional context to distinguish different FuncOps
137+ metadata : additional metadata to distinguish different FuncOps
106138 public: whether the visibility should be marked public
107139
108140 Returns:
109141 FuncOp
110142 """
111- if context is None :
112- context = tuple ()
113- key = (callable_ , * context , * pipeline )
114- if func_op := get_cached (ctx , key ):
115- return func_op
143+ if metadata is None :
144+ metadata = tuple ()
145+ key = (callable_ , * metadata , * pipeline )
146+ if callable_ is not None :
147+ if func_op := get_cached (ctx , key ):
148+ return func_op
116149 func_op = lower_callable_to_funcop (ctx , callable_ , call_jaxpr , public = public )
117150 cache (ctx , key , func_op )
118151 return func_op
@@ -135,10 +168,10 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
135168
136169 kwargs = {}
137170 kwargs ["ctx" ] = ctx .module_context
138- if not isinstance (callable_ , functools .partial ):
139- name = callable_ .__name__
140- else :
171+ if isinstance (callable_ , functools .partial ):
141172 name = callable_ .func .__name__ + ".partial"
173+ else :
174+ name = callable_ .__name__
142175
143176 kwargs ["name" ] = name
144177 kwargs ["jaxpr" ] = call_jaxpr
@@ -154,28 +187,7 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr, public=False):
154187 if isinstance (callable_ , qml .QNode ):
155188 func_op .attributes ["qnode" ] = ir .UnitAttr .get ()
156189
157- diff_method = str (callable_ .diff_method )
158-
159- if diff_method == "best" :
160-
161- def only_single_expval ():
162- found_expval = False
163- for eqn in call_jaxpr .eqns :
164- name = eqn .primitive .name
165- if name in {"probs" , "counts" , "sample" }:
166- return False
167- elif name == "expval" :
168- if found_expval :
169- return False
170- found_expval = True
171- return True
172-
173- device_name = getattr (getattr (callable_ , "device" , None ), "name" , None )
174-
175- if device_name and "lightning" in device_name and only_single_expval ():
176- diff_method = "adjoint"
177- else :
178- diff_method = "parameter-shift"
190+ diff_method = _calculate_diff_method (callable_ , call_jaxpr )
179191
180192 func_op .attributes ["diff_method" ] = ir .StringAttr .get (diff_method )
181193
@@ -195,7 +207,7 @@ def only_single_expval():
195207 return func_op
196208
197209
198- def get_or_create_qnode_funcop (ctx , callable_ , call_jaxpr , pipeline , context ):
210+ def get_or_create_qnode_funcop (ctx , callable_ , call_jaxpr , pipeline , metadata ):
199211 """A wrapper around lower_qnode_to_funcop that will cache the FuncOp.
200212
201213 Args:
@@ -205,11 +217,11 @@ def get_or_create_qnode_funcop(ctx, callable_, call_jaxpr, pipeline, context):
205217 Returns:
206218 FuncOp
207219 """
208- if context is None :
209- context = tuple ()
220+ if metadata is None :
221+ metadata = tuple ()
210222 if callable_ .static_argnums :
211223 return lower_qnode_to_funcop (ctx , callable_ , call_jaxpr , pipeline )
212- key = (callable_ , * context , * pipeline )
224+ key = (callable_ , * metadata , * pipeline )
213225 if func_op := get_cached (ctx , key ):
214226 return func_op
215227 func_op = lower_qnode_to_funcop (ctx , callable_ , call_jaxpr , pipeline )
0 commit comments