Skip to content

Commit a0622ac

Browse files
authored
Set default diff_method to adjoint for lightning devices (#1961)
**Context:** Currently parameter-shift is used by default for differentiation, when the `diff_method` is not provided. This is the most universal method, but not the most performant, especially on lightning devices which support adjoint differentiation (for expval) **Description of the Change:** For lightning devices, set `diff_method` to `adjoint` when it is set to `best`, which is the default argument. Note: currently only a single expval is supported - this needs to be fixed. See [discussion](#1961 (comment)) **Benefits:** Significantly faster differentiation performance on lightning devices. **Possible Drawbacks:** **Related GitHub Issues:** #1171 [sc-74760]
1 parent 90e5d0b commit a0622ac

File tree

3 files changed

+208
-96
lines changed

3 files changed

+208
-96
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@
44

55
<h3>Improvements 🛠</h3>
66

7+
* Adjoint differentiation is used by default when executing on lightning devices, significantly reduces gradient computation time.
8+
[(#1961)](https://github.com/PennyLaneAI/catalyst/pull/1961)
9+
710
* Added `detensorizefunctionboundary` pass to remove scalar tensors across function boundaries and enabled `symbol-dce` pass to remove dead functions, reducing the number of instructions for compilation.
811
[(#1904)](https://github.com/PennyLaneAI/catalyst/pull/1904)
912

frontend/catalyst/jax_primitives_utils.py

Lines changed: 25 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,7 @@
2323
from jax.interpreters import mlir
2424
from jaxlib.mlir.dialects.builtin import ModuleOp
2525
from jaxlib.mlir.dialects.func import CallOp
26-
from mlir_quantum.dialects._transform_ops_gen import (
27-
ApplyRegisteredPassOp,
28-
NamedSequenceOp,
29-
YieldOp,
30-
)
26+
from mlir_quantum.dialects._transform_ops_gen import ApplyRegisteredPassOp, NamedSequenceOp, YieldOp
3127
from mlir_quantum.dialects.catalyst import LaunchKernelOp
3228

3329

@@ -113,13 +109,30 @@ def lower_callable_to_funcop(ctx, callable_, call_jaxpr):
113109

114110
if isinstance(callable_, qml.QNode):
115111
func_op.attributes["qnode"] = ir.UnitAttr.get()
116-
# "best", the default option in PennyLane, chooses backprop on the device
117-
# if supported and parameter-shift otherwise. Emulating the same behaviour
118-
# would require generating code to query the device.
119-
# For simplicity, Catalyst instead defaults to parameter-shift.
120-
diff_method = (
121-
"parameter-shift" if callable_.diff_method == "best" else str(callable_.diff_method)
122-
)
112+
113+
diff_method = str(callable_.diff_method)
114+
115+
if diff_method == "best":
116+
117+
def only_single_expval():
118+
found_expval = False
119+
for eqn in call_jaxpr.eqns:
120+
name = eqn.primitive.name
121+
if name in {"probs", "counts", "sample"}:
122+
return False
123+
elif name == "expval":
124+
if found_expval:
125+
return False
126+
found_expval = True
127+
return True
128+
129+
device_name = getattr(getattr(callable_, "device", None), "name", None)
130+
131+
if device_name and "lightning" in device_name and only_single_expval():
132+
diff_method = "adjoint"
133+
else:
134+
diff_method = "parameter-shift"
135+
123136
func_op.attributes["diff_method"] = ir.StringAttr.get(diff_method)
124137

125138
return func_op

0 commit comments

Comments
 (0)