Skip to content

Commit 92828f2

Browse files
rniczhjoeycarterpaul0403dime10
authored
Fix type promotion on branch (#1977)
**Context:** ```python @qml.qjit(autograph=True) @qml.qnode(qml.device('lightning.qubit', wires=1)) def circuit(cond1, cond2): if cond1: res = False elif cond2: res = 0.5 else: res = False return res circuit(False, True) ``` Should return value with type float64 instead of bool **Description of the Change:** Choose the type that match the promoted type. For an example: ``` branch_avals: [[ShapedArray(float64[], weak_type=True)], [ShapedArray(bool[])]] promoted_dtypes: [dtype('float64')] ``` Expected to return matched type `((ShapedArray(float64[], weak_type=True), ...),)` **Benefits:** **Possible Drawbacks:** **Related GitHub Issues:** Fixes #1958 [sc-97050] --------- Co-authored-by: Joey Carter <[email protected]> Co-authored-by: Paul <[email protected]> Co-authored-by: David Ittah <[email protected]>
1 parent 541c97d commit 92828f2

File tree

3 files changed

+137
-34
lines changed

3 files changed

+137
-34
lines changed

doc/releases/changelog-dev.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,9 @@
6060

6161
<h3>Bug fixes 🐛</h3>
6262

63+
* Fix type promotion on conditional branches, where the return values from `cond` should be the promoted one.
64+
[(#1977)](https://github.com/PennyLaneAI/catalyst/pull/1977)
65+
6366
* Fix wrong handling of partitioned shots in the decomposition pass of `measurements_from_samples`.
6467
[(#1981)](https://github.com/PennyLaneAI/catalyst/pull/1981)
6568

frontend/catalyst/api_extensions/control_flow.py

Lines changed: 30 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@
6060
HybridOp,
6161
HybridOpRegion,
6262
QRegPromise,
63+
_promote_jaxpr_types,
6364
trace_function,
6465
trace_quantum_operations,
6566
unify_convert_result_types,
@@ -162,7 +163,8 @@ def cond_else():
162163
163164
The conditional function is permitted to also return values.
164165
Any value that is supported by JAX JIT compilation is supported as a return
165-
type.
166+
type. If provided, return types need to be identical or at least promotable across both
167+
branches.
166168
167169
.. code-block:: python
168170
@@ -185,31 +187,22 @@ def conditional_fn():
185187
There are various constraints and restrictions that should be kept in mind
186188
when working with conditionals in Catalyst.
187189
188-
The return values of all branches of :func:`~.cond` must be the same type.
189-
Returning different types, or ommitting a return value in one branch (e.g.,
190+
The return values of all branches of :func:`~.cond` must be the same shape.
191+
Returning different shapes, or ommitting a return value in one branch (e.g.,
190192
returning ``None``) but not in others will result in an error.
191193
194+
However, the return values of all branches of :func:`~.cond` can be different data types.
195+
In this case, the return types will automatically be promoted to the next common larger
196+
type.
197+
192198
>>> @qjit
193199
... def f(x: float):
194200
... @cond(x > 1.5)
195201
... def cond_fn():
196202
... return x ** 2 # float
197203
... @cond_fn.otherwise
198204
... def else_branch():
199-
... return 6 # int
200-
... return cond_fn()
201-
TypeError: Conditional requires consistent return types across all branches, got:
202-
- Branch at index 0: [ShapedArray(float64[], weak_type=True)]
203-
- Branch at index 1: [ShapedArray(int64[], weak_type=True)]
204-
Please specify an else branch if none was specified.
205-
>>> @qjit
206-
... def f(x: float):
207-
... @cond(x > 1.5)
208-
... def cond_fn():
209-
... return x ** 2 # float
210-
... @cond_fn.otherwise
211-
... def else_branch():
212-
... return 6. # float
205+
... return 6 # int (promotable to float)
213206
... return cond_fn()
214207
>>> f(1.5)
215208
Array(6., dtype=float64)
@@ -224,10 +217,9 @@ def conditional_fn():
224217
... def cond_fn():
225218
... return x ** 2
226219
... return cond_fn()
227-
TypeError: Conditional requires consistent return types across all branches, got:
228-
- Branch at index 0: [ShapedArray(float64[], weak_type=True)]
229-
- Branch at index 1: []
230-
Please specify an else branch if none was specified.
220+
TypeError: Conditional requires a consistent return structure across all branches! Got
221+
PyTreeDef(None) and PyTreeDef(*). Please specify an else branch if PyTreeDef(None) was
222+
specified.
231223
232224
>>> @qjit
233225
... def f(x: float):
@@ -774,18 +766,16 @@ def _call_with_quantum_ctx(self):
774766
out_tree = out_sigs[-1].out_tree()
775767
all_consts = [s.out_consts() for s in out_sigs]
776768
out_types = [s.out_type() for s in out_sigs]
777-
# FIXME: We want to perform the result unificaiton here:
778-
# all_jaxprs = [s.out_initial_jaxpr() for s in out_sigs]
779-
# all_noimplouts = [s.num_implicit_outputs() for s in out_sigs]
780-
# _, out_type, _, all_consts = unify_convert_result_types(
781-
# all_jaxprs, all_consts, all_noimplouts
782-
# )
783-
# Unfortunately, we can not do this beacuse some tracers (specifically, the results of
784-
# ``qml.measure``) might not have their source Jaxpr equation yet. Thus, we delay the
785-
# unification until the quantum tracing is done. The consequence of that: we have to guess
786-
# the output type now and if we fail to do so, we might face MLIR type error down the
787-
# pipeline.
769+
770+
# Select the output type of the one with the promoted dtype among all branches
788771
out_type = out_types[-1]
772+
branch_avals = [[aval for aval, _ in branch_out_type] for branch_out_type in out_types]
773+
promoted_dtypes = _promote_jaxpr_types(branch_avals)
774+
775+
out_type = [
776+
(aval.update(dtype=dtype), expl)
777+
for dtype, (aval, expl) in zip(promoted_dtypes, out_type)
778+
]
789779

790780
# Create output tracers in the outer tracing context
791781
out_expanded_classical_tracers = output_type_to_tracers(
@@ -1563,14 +1553,20 @@ def _assert_cond_result_structure(trees: List[PyTreeDef]):
15631553
if tree != expected_tree:
15641554
raise TypeError(
15651555
"Conditional requires a consistent return structure across all branches! "
1566-
f"Got {tree} and {expected_tree}."
1556+
f"Got {tree} and {expected_tree}. Please specify an else branch if PyTreeDef(None) "
1557+
"was specified."
15671558
)
15681559

15691560

15701561
def _assert_cond_result_types(signatures: List[List[AbstractValue]]):
15711562
"""Ensure a consistent type signature across branch results."""
15721563
num_results = len(signatures[0])
1573-
assert all(len(sig) == num_results for sig in signatures), "mismatch: number or results"
1564+
1565+
if not all(len(sig) == num_results for sig in signatures):
1566+
raise TypeError(
1567+
"Conditional requires a consistent number of results across all branches! It might "
1568+
"happen when some of the branch returns dynamic shape and some return constant shape."
1569+
)
15741570

15751571
for i in range(num_results):
15761572
aval_slice = [avals[i] for avals in signatures]

frontend/test/pytest/test_conditionals.py

Lines changed: 104 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from textwrap import dedent
1616

17+
import jax
1718
import jax.numpy as jnp
1819
import numpy as np
1920
import pennylane as qml
@@ -64,6 +65,7 @@ def asline(text):
6465
assert asline(expected) == asline(circuit.jaxpr)
6566

6667

68+
# pylint: disable=too-many-public-methods,too-many-lines
6769
class TestCond:
6870
"""Test suite for the Cond functionality in Catalyst."""
6971

@@ -311,6 +313,108 @@ def cond_else():
311313

312314
assert 0 == circuit()
313315

316+
def test_branch_multi_return_type_unification_qjit_2(self):
317+
"""Test that unification happens before the results of the cond primitve is available."""
318+
319+
@qjit
320+
def circuit(cond1, cond2):
321+
@cond(cond1)
322+
def cond_fn():
323+
return False
324+
325+
@cond_fn.else_if(cond2)
326+
def cond_fn_2():
327+
return 0.5
328+
329+
@cond_fn.otherwise
330+
def cond_fn_3():
331+
return False
332+
333+
r = cond_fn()
334+
assert r.dtype is jnp.dtype(
335+
"float64" if jax.config.values["jax_enable_x64"] else "float32"
336+
)
337+
return r
338+
339+
assert 0.5 == circuit(False, True)
340+
341+
def test_branch_multi_return_type_unification_qjit_3(self):
342+
"""Test that unification happens before the results of the cond primitve is available."""
343+
344+
@qjit
345+
def circuit(cond1, cond2):
346+
@cond(cond1)
347+
def cond_fn():
348+
return False
349+
350+
@cond_fn.else_if(cond2)
351+
def cond_fn_2():
352+
return False
353+
354+
@cond_fn.otherwise
355+
def cond_fn_3():
356+
return 0.5
357+
358+
r = cond_fn()
359+
assert r.dtype is jnp.dtype(
360+
"float64" if jax.config.values["jax_enable_x64"] else "float32"
361+
)
362+
return r
363+
364+
assert 0.0 == circuit(False, True)
365+
366+
def test_branch_multi_return_type_unification_qjit_4(self):
367+
"""Test that unification happens before the results of the cond primitve is available."""
368+
369+
@qjit
370+
def circuit(cond1, cond2):
371+
@cond(cond1)
372+
def cond_fn():
373+
return {0: True, 1: 0.5}
374+
375+
@cond_fn.else_if(cond2)
376+
def cond_fn_2():
377+
return {0: 0.7, 1: True}
378+
379+
@cond_fn.otherwise
380+
def cond_fn_3():
381+
return {0: True, 1: False}
382+
383+
r = cond_fn()
384+
expected_dtype = jnp.dtype(
385+
"float64" if jax.config.values["jax_enable_x64"] else "float32"
386+
)
387+
assert all(v.dtype is expected_dtype for _, v in r.items())
388+
return r
389+
390+
assert {0: 0.7, 1: 1.0} == circuit(False, True)
391+
392+
def test_qnode_cond_inconsistent_return_types(self, backend):
393+
"""Test that catalyst raises an error when the conditional has inconsistent return types."""
394+
395+
@qjit
396+
@qml.qnode(qml.device(backend, wires=4))
397+
def f(flag, sz):
398+
a = jnp.ones([sz], dtype=float)
399+
b = jnp.zeros([3], dtype=float)
400+
401+
@cond(flag)
402+
def case():
403+
return a
404+
405+
@case.otherwise
406+
def case():
407+
return b
408+
409+
c = case()
410+
return c
411+
412+
with pytest.raises(
413+
TypeError,
414+
match="Conditional requires a consistent number of results across all branches",
415+
):
416+
f(True, 3)
417+
314418
@pytest.mark.xfail(
315419
reason="Inability to apply Jax transformations before the quantum traing is complete"
316420
)

0 commit comments

Comments
 (0)