6060 HybridOp ,
6161 HybridOpRegion ,
6262 QRegPromise ,
63+ _promote_jaxpr_types ,
6364 trace_function ,
6465 trace_quantum_operations ,
6566 unify_convert_result_types ,
@@ -162,7 +163,8 @@ def cond_else():
162163
163164 The conditional function is permitted to also return values.
164165 Any value that is supported by JAX JIT compilation is supported as a return
165- type.
166+ type. If provided, return types need to be identical or at least promotable across both
167+ branches.
166168
167169 .. code-block:: python
168170
@@ -185,31 +187,22 @@ def conditional_fn():
185187 There are various constraints and restrictions that should be kept in mind
186188 when working with conditionals in Catalyst.
187189
188- The return values of all branches of :func:`~.cond` must be the same type .
189- Returning different types , or ommitting a return value in one branch (e.g.,
190+ The return values of all branches of :func:`~.cond` must be the same shape .
191+ Returning different shapes , or ommitting a return value in one branch (e.g.,
190192 returning ``None``) but not in others will result in an error.
191193
194+ However, the return values of all branches of :func:`~.cond` can be different data types.
195+ In this case, the return types will automatically be promoted to the next common larger
196+ type.
197+
192198 >>> @qjit
193199 ... def f(x: float):
194200 ... @cond(x > 1.5)
195201 ... def cond_fn():
196202 ... return x ** 2 # float
197203 ... @cond_fn.otherwise
198204 ... def else_branch():
199- ... return 6 # int
200- ... return cond_fn()
201- TypeError: Conditional requires consistent return types across all branches, got:
202- - Branch at index 0: [ShapedArray(float64[], weak_type=True)]
203- - Branch at index 1: [ShapedArray(int64[], weak_type=True)]
204- Please specify an else branch if none was specified.
205- >>> @qjit
206- ... def f(x: float):
207- ... @cond(x > 1.5)
208- ... def cond_fn():
209- ... return x ** 2 # float
210- ... @cond_fn.otherwise
211- ... def else_branch():
212- ... return 6. # float
205+ ... return 6 # int (promotable to float)
213206 ... return cond_fn()
214207 >>> f(1.5)
215208 Array(6., dtype=float64)
@@ -224,10 +217,9 @@ def conditional_fn():
224217 ... def cond_fn():
225218 ... return x ** 2
226219 ... return cond_fn()
227- TypeError: Conditional requires consistent return types across all branches, got:
228- - Branch at index 0: [ShapedArray(float64[], weak_type=True)]
229- - Branch at index 1: []
230- Please specify an else branch if none was specified.
220+ TypeError: Conditional requires a consistent return structure across all branches! Got
221+ PyTreeDef(None) and PyTreeDef(*). Please specify an else branch if PyTreeDef(None) was
222+ specified.
231223
232224 >>> @qjit
233225 ... def f(x: float):
@@ -774,18 +766,16 @@ def _call_with_quantum_ctx(self):
774766 out_tree = out_sigs [- 1 ].out_tree ()
775767 all_consts = [s .out_consts () for s in out_sigs ]
776768 out_types = [s .out_type () for s in out_sigs ]
777- # FIXME: We want to perform the result unificaiton here:
778- # all_jaxprs = [s.out_initial_jaxpr() for s in out_sigs]
779- # all_noimplouts = [s.num_implicit_outputs() for s in out_sigs]
780- # _, out_type, _, all_consts = unify_convert_result_types(
781- # all_jaxprs, all_consts, all_noimplouts
782- # )
783- # Unfortunately, we can not do this beacuse some tracers (specifically, the results of
784- # ``qml.measure``) might not have their source Jaxpr equation yet. Thus, we delay the
785- # unification until the quantum tracing is done. The consequence of that: we have to guess
786- # the output type now and if we fail to do so, we might face MLIR type error down the
787- # pipeline.
769+
770+ # Select the output type of the one with the promoted dtype among all branches
788771 out_type = out_types [- 1 ]
772+ branch_avals = [[aval for aval , _ in branch_out_type ] for branch_out_type in out_types ]
773+ promoted_dtypes = _promote_jaxpr_types (branch_avals )
774+
775+ out_type = [
776+ (aval .update (dtype = dtype ), expl )
777+ for dtype , (aval , expl ) in zip (promoted_dtypes , out_type )
778+ ]
789779
790780 # Create output tracers in the outer tracing context
791781 out_expanded_classical_tracers = output_type_to_tracers (
@@ -1563,14 +1553,20 @@ def _assert_cond_result_structure(trees: List[PyTreeDef]):
15631553 if tree != expected_tree :
15641554 raise TypeError (
15651555 "Conditional requires a consistent return structure across all branches! "
1566- f"Got { tree } and { expected_tree } ."
1556+ f"Got { tree } and { expected_tree } . Please specify an else branch if PyTreeDef(None) "
1557+ "was specified."
15671558 )
15681559
15691560
15701561def _assert_cond_result_types (signatures : List [List [AbstractValue ]]):
15711562 """Ensure a consistent type signature across branch results."""
15721563 num_results = len (signatures [0 ])
1573- assert all (len (sig ) == num_results for sig in signatures ), "mismatch: number or results"
1564+
1565+ if not all (len (sig ) == num_results for sig in signatures ):
1566+ raise TypeError (
1567+ "Conditional requires a consistent number of results across all branches! It might "
1568+ "happen when some of the branch returns dynamic shape and some return constant shape."
1569+ )
15741570
15751571 for i in range (num_results ):
15761572 aval_slice = [avals [i ] for avals in signatures ]
0 commit comments