Skip to content

Commit 46734c9

Browse files
authored
[Capture] Improve cond by storing missing false branches as empty jaxpr (#8080)
**Context:** Right now we store missing false branches in a cond equation with `None`. This PR simply stores them as jaxpr with no output instead. This just means we have fewer logical branches we need to consider when handling `cond` equations, which will make maintainence easier. **Description of the Change:** Store missing false branches as empty jaxpr instead of `None`. **Benefits:** A reduction in logical complexity. **Possible Drawbacks:** **Related GitHub Issues:** [sc-97723]
1 parent 23d998e commit 46734c9

File tree

9 files changed

+63
-80
lines changed

9 files changed

+63
-80
lines changed

doc/releases/changelog-dev.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,10 @@
453453

454454
<h3>Internal changes ⚙️</h3>
455455

456+
* The `cond` primitive with program capture no longer stores missing false branches as `None`, instead storing them
457+
as jaxprs with no output.
458+
[(#8080)](https://github.com/PennyLaneAI/pennylane/pull/8080)
459+
456460
* Removed unnecessary execution tests along with accuracy validation in `tests/ops/functions/test_map_wires.py`.
457461
[(#8032)](https://github.com/PennyLaneAI/pennylane/pull/8032)
458462

pennylane/capture/base_interpreter.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -533,15 +533,11 @@ def handle_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
533533

534534
for const_slice, jaxpr in zip(consts_slices, jaxpr_branches):
535535
consts = invals[const_slice]
536-
if jaxpr is None:
537-
new_jaxprs.append(None)
538-
new_consts_slices.append(slice(0, 0))
539-
else:
540-
new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
541-
new_jaxprs.append(new_jaxpr.jaxpr)
542-
new_consts.extend(new_jaxpr.consts)
543-
new_consts_slices.append(slice(end_const_ind, end_const_ind + len(new_jaxpr.consts)))
544-
end_const_ind += len(new_jaxpr.consts)
536+
new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
537+
new_jaxprs.append(new_jaxpr.jaxpr)
538+
new_consts.extend(new_jaxpr.consts)
539+
new_consts_slices.append(slice(end_const_ind, end_const_ind + len(new_jaxpr.consts)))
540+
end_const_ind += len(new_jaxpr.consts)
545541

546542
new_args_slice = slice(end_const_ind, None)
547543
return cond_prim.bind(
@@ -694,7 +690,7 @@ def flattened_cond(self, *invals, jaxpr_branches, consts_slices, args_slice):
694690
raise NotImplementedError(
695691
f"{self} does not yet support jitting cond with abstract conditions."
696692
)
697-
if pred and jaxpr is not None:
693+
if pred:
698694
return copy(self).eval(jaxpr, consts, *args)
699695
return ()
700696

pennylane/decomposition/collect_resource_ops.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -83,11 +83,10 @@ def explore_all_branches(self, *invals, jaxpr_branches, consts_slices, args_slic
8383
outvals = ()
8484
for _, jaxpr, consts_slice in zip(conditions, jaxpr_branches, consts_slices):
8585
consts = invals[consts_slice]
86-
if jaxpr is not None:
87-
dummy = copy(self).eval(jaxpr, consts, *args)
88-
# The cond_prim may or may not expect outvals, so we need to check whether
89-
# the first branch returns something significant. If so, we use the return
90-
# value of the first branch as the outvals of this cond_prim.
91-
if dummy and not outvals:
92-
outvals = dummy
86+
dummy = copy(self).eval(jaxpr, consts, *args)
87+
# The cond_prim may or may not expect outvals, so we need to check whether
88+
# the first branch returns something significant. If so, we use the return
89+
# value of the first branch as the outvals of this cond_prim.
90+
if dummy and not outvals:
91+
outvals = dummy
9392
return outvals

pennylane/ops/op_math/condition.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,10 @@ def new_fn(*args, **kwargs):
7575
return fn
7676

7777

78+
def _empty_return_fn(*_, **__):
79+
return None
80+
81+
7882
class Conditional(SymbolicOp, Operation):
7983
"""A Conditional Operation.
8084
@@ -286,19 +290,17 @@ def __call_capture_enabled(self, *args, **kwargs):
286290
raise ValueError(f"Condition predicate must be a scalar. Got {pred_shape}.")
287291
conditions.append(pred)
288292
if fn is None:
289-
jaxpr_branches.append(None)
290-
consts_slices.append(slice(0, 0))
291-
else:
292-
f = FlatFn(functools.partial(fn, **kwargs))
293-
if jax.config.jax_dynamic_shapes:
294-
f = _add_abstract_shapes(f)
295-
jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args)
296-
jaxpr_branches.append(jaxpr.jaxpr)
297-
consts_slices.append(slice(end_const_ind, end_const_ind + len(jaxpr.consts)))
298-
consts += jaxpr.consts
299-
end_const_ind += len(jaxpr.consts)
300-
301-
_validate_jaxpr_returns(jaxpr_branches)
293+
fn = _empty_return_fn
294+
f = FlatFn(functools.partial(fn, **kwargs))
295+
if jax.config.jax_dynamic_shapes:
296+
f = _add_abstract_shapes(f)
297+
jaxpr = jax.make_jaxpr(f, abstracted_axes=abstracted_axes)(*args)
298+
jaxpr_branches.append(jaxpr.jaxpr)
299+
consts_slices.append(slice(end_const_ind, end_const_ind + len(jaxpr.consts)))
300+
consts += jaxpr.consts
301+
end_const_ind += len(jaxpr.consts)
302+
303+
_validate_jaxpr_returns(jaxpr_branches, self.otherwise_fn)
302304
flat_args, _ = jax.tree_util.tree_flatten(args)
303305
results = cond_prim.bind(
304306
*conditions,
@@ -746,21 +748,15 @@ def _validate_abstract_values(
746748
_aval_mismatch_error(branch_type, branch_index, i, outval, expected_outval)
747749

748750

749-
def _validate_jaxpr_returns(jaxpr_branches):
751+
def _validate_jaxpr_returns(jaxpr_branches, false_fn):
750752
out_avals_true = [out.aval for out in jaxpr_branches[0].outvars]
751-
for idx, jaxpr_branch in enumerate(jaxpr_branches):
752-
753-
if idx == 0:
754-
continue
755753

756-
if jaxpr_branch is None:
757-
if out_avals_true:
758-
raise ValueError(
759-
"The false branch must be provided if the true branch returns any variables"
760-
)
761-
# this is tested, but coverage does not pick it up
762-
continue # pragma: no cover
754+
if false_fn is None and out_avals_true:
755+
raise ValueError(
756+
"The false branch must be provided if the true branch returns any variables"
757+
)
763758

759+
for idx, jaxpr_branch in enumerate(jaxpr_branches[1:], start=1):
764760
out_avals_branch = [out.aval for out in jaxpr_branch.outvars]
765761
branch_type = "elif" if idx < len(jaxpr_branches) - 1 else "false"
766762
_validate_abstract_values(out_avals_branch, out_avals_true, branch_type, idx - 1)
@@ -805,8 +801,6 @@ def _(*all_args, jaxpr_branches, consts_slices, args_slice):
805801

806802
for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices):
807803
consts = all_args[const_slice]
808-
if jaxpr is None:
809-
continue
810804
if isinstance(pred, qml.measurements.MeasurementValue):
811805

812806
with qml.queuing.AnnotatedQueue() as q:

pennylane/tape/plxpr_conversion.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,6 @@ def _(self, *all_args, jaxpr_branches, consts_slices, args_slice):
150150

151151
for pred, jaxpr, const_slice in zip(conditions, jaxpr_branches, consts_slices):
152152
consts = all_args[const_slice]
153-
if jaxpr is None:
154-
continue
155153
if isinstance(pred, qml.measurements.MeasurementValue):
156154
if jaxpr.outvars:
157155
outvals = [v.aval for v in jaxpr.outvars]

pennylane/transforms/defer_measurements.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -408,11 +408,6 @@ def _(self, *invals, jaxpr_branches, consts_slices, args_slice):
408408
args = invals[args_slice]
409409

410410
for i, (condition, jaxpr) in enumerate(zip(conditions, jaxpr_branches, strict=True)):
411-
if jaxpr is None:
412-
# If a false branch isn't provided, the jaxpr corresponding to the condition
413-
# for the false branch will be None. That is the only scenario where we would
414-
# reach here.
415-
continue
416411

417412
if isinstance(condition, MeasurementValue):
418413
control_wires = Wires([m.wires[0] for m in condition.measurements])

pennylane/transforms/optimization/merge_amplitude_embedding.py

Lines changed: 21 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -251,33 +251,27 @@ def _(self, *invals, jaxpr_branches, consts_slices, args_slice):
251251

252252
for const_slice, jaxpr in zip(consts_slices, jaxpr_branches):
253253
consts = invals[const_slice]
254-
if jaxpr is None:
255-
new_jaxprs.append(None)
256-
new_consts_slices.append(slice(0, 0))
257-
else:
258-
new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
259-
260-
# Update state so far so collisions with
261-
# newly seen states from the branches continue to be
262-
# detected after the cond
263-
curr_wires |= self.state["visited_wires"]
264-
curr_dynamic_wires_found = self.state["dynamic_wires_found"]
265-
curr_ops_found = self.state["ops_found"]
266-
267-
# Reset state for the next branch so we don't get false positive collisions
268-
# (copy so if state mutates we preserved true initial state)
269-
self.state = {
270-
"visited_wires": copy(initial_wires),
271-
"dynamic_wires_found": initial_dynamic_wires_found,
272-
"ops_found": initial_ops_found,
273-
}
274-
275-
new_jaxprs.append(new_jaxpr.jaxpr)
276-
new_consts.extend(new_jaxpr.consts)
277-
new_consts_slices.append(
278-
slice(end_const_ind, end_const_ind + len(new_jaxpr.consts))
279-
)
280-
end_const_ind += len(new_jaxpr.consts)
254+
new_jaxpr = jaxpr_to_jaxpr(copy(self), jaxpr, consts, *args)
255+
256+
# Update state so far so collisions with
257+
# newly seen states from the branches continue to be
258+
# detected after the cond
259+
curr_wires |= self.state["visited_wires"]
260+
curr_dynamic_wires_found = curr_dynamic_wires_found or self.state["dynamic_wires_found"]
261+
curr_ops_found = curr_ops_found or self.state["ops_found"]
262+
263+
# Reset state for the next branch so we don't get false positive collisions
264+
# (copy so if state mutates we preserved true initial state)
265+
self.state = {
266+
"visited_wires": copy(initial_wires),
267+
"dynamic_wires_found": initial_dynamic_wires_found,
268+
"ops_found": initial_ops_found,
269+
}
270+
271+
new_jaxprs.append(new_jaxpr.jaxpr)
272+
new_consts.extend(new_jaxpr.consts)
273+
new_consts_slices.append(slice(end_const_ind, end_const_ind + len(new_jaxpr.consts)))
274+
end_const_ind += len(new_jaxpr.consts)
281275

282276
# Reset state to all updates from all branches in the cond
283277
self.state = {

tests/capture/test_base_interpreter.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -501,7 +501,8 @@ def f():
501501

502502
jaxpr = jax.make_jaxpr(f)(True)
503503

504-
assert jaxpr.eqns[0].params["jaxpr_branches"][-1] is None # no false branch
504+
false_branch = jaxpr.eqns[0].params["jaxpr_branches"][-1]
505+
assert len(false_branch.eqns) == 0
505506

506507
with qml.queuing.AnnotatedQueue() as q_true:
507508
jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, True)

tests/capture/test_capture_cond.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,9 @@ def f():
409409
assert len(true_fn.outvars) == 0
410410
assert true_fn.eqns[0].primitive == qml.X._primitive # pylint: disable=protected-access
411411

412-
assert jaxpr.eqns[0].params["jaxpr_branches"][-1] is None
412+
false_fn = jaxpr.eqns[0].params["jaxpr_branches"][-1]
413+
assert len(false_fn.eqns) == 0
414+
assert len(false_fn.outvars) == 0
413415

414416

415417
dev = qml.device("default.qubit", wires=3)

0 commit comments

Comments
 (0)