|
23 | 23 | from jax.interpreters import mlir |
24 | 24 | from jaxlib.mlir.dialects.builtin import ModuleOp |
25 | 25 | from jaxlib.mlir.dialects.func import CallOp |
26 | | -from mlir_quantum.dialects._transform_ops_gen import ( |
27 | | - ApplyRegisteredPassOp, |
28 | | - NamedSequenceOp, |
29 | | - YieldOp, |
30 | | -) |
| 26 | +from mlir_quantum.dialects._transform_ops_gen import ApplyRegisteredPassOp, NamedSequenceOp, YieldOp |
31 | 27 | from mlir_quantum.dialects.catalyst import LaunchKernelOp |
32 | 28 |
|
33 | 29 |
|
@@ -113,13 +109,30 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr): |
113 | 109 |
|
114 | 110 | if isinstance(callable_, qml.QNode): |
115 | 111 | func_op.attributes["qnode"] = ir.UnitAttr.get() |
116 | | - # "best", the default option in PennyLane, chooses backprop on the device |
117 | | - # if supported and parameter-shift otherwise. Emulating the same behaviour |
118 | | - # would require generating code to query the device. |
119 | | - # For simplicity, Catalyst instead defaults to parameter-shift. |
120 | | - diff_method = ( |
121 | | - "parameter-shift" if callable_.diff_method == "best" else str(callable_.diff_method) |
122 | | - ) |
| 112 | + |
| 113 | + diff_method = str(callable_.diff_method) |
| 114 | + |
| 115 | + if diff_method == "best": |
| 116 | + |
| 117 | + def only_single_expval(): |
| 118 | + found_expval = False |
| 119 | + for eqn in call_jaxpr.eqns: |
| 120 | + name = eqn.primitive.name |
| 121 | + if name in {"probs", "counts", "sample"}: |
| 122 | + return False |
| 123 | + elif name == "expval": |
| 124 | + if found_expval: |
| 125 | + return False |
| 126 | + found_expval = True |
| 127 | + return True |
| 128 | + |
| 129 | + device_name = getattr(getattr(callable_, "device", None), "name", None) |
| 130 | + |
| 131 | + if device_name and "lightning" in device_name and only_single_expval(): |
| 132 | + diff_method = "adjoint" |
| 133 | + else: |
| 134 | + diff_method = "parameter-shift" |
| 135 | + |
123 | 136 | func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method) |
124 | 137 |
|
125 | 138 | return func_op |
|
0 commit comments