diff --git a/frontend/catalyst/from_plxpr/control_flow.py b/frontend/catalyst/from_plxpr/control_flow.py index 37de2ed5d..bd264849f 100644 --- a/frontend/catalyst/from_plxpr/control_flow.py +++ b/frontend/catalyst/from_plxpr/control_flow.py @@ -26,25 +26,70 @@ from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter -from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder +from catalyst.from_plxpr.qubit_handler import ( + QubitHandler, + QubitIndexRecorder, + _get_dynamically_allocated_qregs, +) from catalyst.jax_extras import jaxpr_pad_consts from catalyst.jax_primitives import cond_p, for_p, while_p -def _calling_convention(interpreter, closed_jaxpr, *args_plus_qreg): - # The last arg is the scope argument for the body jaxpr - *args, qreg = args_plus_qreg +def _calling_convention( + interpreter, closed_jaxpr, *args_plus_qregs, outer_dynqreg_handlers=(), return_qreg=True +): + # Arg structure (all args are tracers, since this function is to be `make_jaxpr`'d): + # Regular args, then dynamically allocated qregs, then global qreg + # TODO: merge dynamically allocaed qregs into regular args? + # But this is tricky, since qreg arguments need all the SSA value semantics conversion infra + # and are different from the regular plain arguments. + *args_plus_dynqregs, global_qreg = args_plus_qregs + num_dynamic_alloced_qregs = len(outer_dynqreg_handlers) + args, dynalloced_qregs = ( + args_plus_dynqregs[: len(args_plus_dynqregs) - num_dynamic_alloced_qregs], + args_plus_dynqregs[len(args_plus_dynqregs) - num_dynamic_alloced_qregs :], + ) # Launch a new interpreter for the body region # A new interpreter's root qreg value needs a new recorder converter = copy(interpreter) converter.qubit_index_recorder = QubitIndexRecorder() - init_qreg = QubitHandler(qreg, converter.qubit_index_recorder) + init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder) converter.init_qreg = init_qreg - # pylint: disable-next=cell-var-from-loop + # add dynamic qregs to recorder + qreg_map = {} + dyn_qreg_handlers = [] + for dyn_qreg, outer_dynqreg_handler in zip( + dynalloced_qregs, outer_dynqreg_handlers, strict=True + ): + dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder) + dyn_qreg_handlers.append(dyn_qreg_handler) + + # plxpr global wire index does not change across scopes + # So scope arg dynamic qregs need to have the same root hash as their corresponding + # qreg tracers outside + dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash + + # Each qreg argument of the subscope corresponds to a qreg from the outer scope + qreg_map[outer_dynqreg_handler] = dyn_qreg_handler + + # The new interpreter's recorder needs to be updated to include the qreg args + # of this scope, instead of the outer qregs + if qreg_map: + for k, outer_dynqreg_handler in interpreter.qubit_index_recorder.map.items(): + converter.qubit_index_recorder[k] = qreg_map[outer_dynqreg_handler] + retvals = converter(closed_jaxpr, *args) + if not return_qreg: + return retvals + init_qreg.insert_all_dangling_qubits() + + # Return all registers + for dyn_qreg_handler in dyn_qreg_handlers: + dyn_qreg_handler.insert_all_dangling_qubits() + retvals.append(dyn_qreg_handler.get()) return *retvals, converter.init_qreg.get() @@ -90,7 +135,18 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): """Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive""" args = plxpr_invals[args_slice] self.init_qreg.insert_all_dangling_qubits() - args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args + + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + plxpr_invals, self.qubit_index_recorder, self.init_qreg + ) + + # Add the qregs to the args + args_plus_qreg = [ + *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], + self.init_qreg.get(), + ] + converted_jaxpr_branches = [] all_consts = [] @@ -103,7 +159,9 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): converted_jaxpr_branch = None closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts) - f = partial(_calling_convention, self, closed_jaxpr) + f = partial( + _calling_convention, self, closed_jaxpr, outer_dynqreg_handlers=dynalloced_qregs + ) converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg) all_consts += converted_jaxpr_branch.consts @@ -112,6 +170,8 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]] # Build Catalyst compatible input values + # strip global wire indices of dynamic wires + all_consts = tuple(const for const in all_consts if const not in dynalloced_wire_global_indices) cond_invals = [*predicate, *all_consts, *args_plus_qreg] # Perform the binding @@ -121,9 +181,12 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice): nimplicit_outputs=None, ) - # We assume the last output value is the returned qreg. + # Output structure: + # First a list of dynamically allocated qregs, then the global qreg # Update the current qreg and remove it from the output values. self.init_qreg.set(outvals.pop()) + for dyn_qreg in reversed(dynalloced_qregs): + dyn_qreg.set(outvals.pop()) # Return only the output values that match the plxpr output values return outvals @@ -193,9 +256,15 @@ def handle_for_loop( # Add the iteration start and the qreg to the args self.init_qreg.insert_all_dangling_qubits() + + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + plxpr_invals, self.qubit_index_recorder, self.init_qreg + ) + start_plus_args_plus_qreg = [ start, *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], self.init_qreg.get(), ] @@ -203,7 +272,12 @@ def handle_for_loop( jaxpr = ClosedJaxpr(jaxpr_body_fn, consts) - f = partial(_calling_convention, self, jaxpr) + f = partial( + _calling_convention, + self, + jaxpr, + outer_dynqreg_handlers=dynalloced_qregs, + ) converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg) converted_closed_jaxpr_branch = ClosedJaxpr( @@ -211,7 +285,9 @@ def handle_for_loop( ) # Build Catalyst compatible input values + # strip global wire indices of dynamic wires new_consts = converted_jaxpr_branch.consts + new_consts = tuple(const for const in new_consts if const not in dynalloced_wire_global_indices) for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg] # Config additional for loop settings @@ -227,10 +303,14 @@ def handle_for_loop( preserve_dimensions=True, ) - # We assume the last output value is the returned qreg. + # Output structure: + # First a list of dynamically allocated qregs, then the global qreg # Update the current qreg and remove it from the output values. self.init_qreg.set(outvals.pop()) + for dyn_qreg in reversed(dynalloced_qregs): + dyn_qreg.set(outvals.pop()) + # Return only the output values that match the plxpr output values return outvals @@ -289,14 +369,21 @@ def handle_while_loop( ): """Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive""" self.init_qreg.insert_all_dangling_qubits() + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + plxpr_invals, self.qubit_index_recorder, self.init_qreg + ) consts_body = plxpr_invals[body_slice] consts_cond = plxpr_invals[cond_slice] args = plxpr_invals[args_slice] - args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args + args_plus_qreg = [ + *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], + self.init_qreg.get(), + ] # Add the qreg to the args jaxpr = ClosedJaxpr(jaxpr_body_fn, consts_body) - f = partial(_calling_convention, self, jaxpr) + f = partial(_calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs) converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr converted_body_closed_jaxpr_branch = ClosedJaxpr( @@ -307,30 +394,22 @@ def handle_while_loop( # We need to be able to handle arbitrary plxpr here. # But we want to be able to create a state where: # * We do not pass the quantum register as an argument. - # So let's just remove the quantum register here at the end - jaxpr = ClosedJaxpr(jaxpr_cond_fn, consts_cond) - def remove_qreg(*args_plus_qreg): - # The last arg is the scope argument for the body jaxpr - *args, qreg = args_plus_qreg - - # Launch a new interpreter for the body region - # A new interpreter's root qreg value needs a new recorder - converter = copy(self) - converter.qubit_index_recorder = QubitIndexRecorder() - init_qreg = QubitHandler(qreg, converter.qubit_index_recorder) - converter.init_qreg = init_qreg - - return converter(jaxpr, *args) + f_remove_qreg = partial( + _calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs, return_qreg=False + ) - converted_cond_jaxpr_branch = jax.make_jaxpr(remove_qreg)(*args_plus_qreg).jaxpr + converted_cond_jaxpr_branch = jax.make_jaxpr(f_remove_qreg)(*args_plus_qreg).jaxpr converted_cond_closed_jaxpr_branch = ClosedJaxpr( convert_constvars_jaxpr(converted_cond_jaxpr_branch), () ) # Build Catalyst compatible input values + consts_body = tuple( + const for const in consts_body if const not in dynalloced_wire_global_indices + ) while_loop_invals = [*consts_cond, *consts_body, *args_plus_qreg] # Perform the binding @@ -348,5 +427,8 @@ def remove_qreg(*args_plus_qreg): # Update the current qreg and remove it from the output values. self.init_qreg.set(outvals.pop()) + for dyn_qreg in reversed(dynalloced_qregs): + dyn_qreg.set(outvals.pop()) + # Return only the output values that match the plxpr output values return outvals diff --git a/frontend/catalyst/from_plxpr/from_plxpr.py b/frontend/catalyst/from_plxpr/from_plxpr.py index 5dbf302b8..081f4386f 100644 --- a/frontend/catalyst/from_plxpr/from_plxpr.py +++ b/frontend/catalyst/from_plxpr/from_plxpr.py @@ -52,6 +52,7 @@ from catalyst.from_plxpr.qubit_handler import ( QubitHandler, QubitIndexRecorder, + _get_dynamically_allocated_qregs, get_in_qubit_values, is_dynamically_allocated_wire, ) @@ -642,16 +643,6 @@ def handle_subroutine(self, *args, **kwargs): Transform the subroutine from PLxPR into JAXPR with quantum primitives. """ - if any(is_dynamically_allocated_wire(arg) for arg in args): - raise NotImplementedError( - textwrap.dedent( - """ - Dynamically allocated wires in a parent scope cannot be used in a child - scope yet. Please consider dynamical allocation inside the child scope. - """ - ) - ) - backup = dict(self.init_qreg) self.init_qreg.insert_all_dangling_qubits() @@ -659,20 +650,73 @@ def handle_subroutine(self, *args, **kwargs): plxpr = kwargs["jaxpr"] transformed = self.subroutine_cache.get(plxpr) - def wrapper(qreg, *args): - # Launch a new interpreter for the new subroutine region + dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs( + args, self.qubit_index_recorder, self.init_qreg + ) + + # Convert global wire indices into local indices + new_args = () + wire_label_arg_to_tracer_arg_index = {} + for i, arg in enumerate(args): + if arg in dynalloced_wire_global_indices: + wire_label_arg_to_tracer_arg_index[arg] = i + new_args += (self.qubit_index_recorder[arg].global_index_to_local_index(arg),) + else: + new_args += (arg,) + + def wrapper(*qregs_plus_args): + global_qreg, *dynqregs_plus_args = qregs_plus_args + num_dynamic_alloced_qregs = len(dynalloced_qregs) + _dynalloced_qregs, args = ( + dynqregs_plus_args[:num_dynamic_alloced_qregs], + dynqregs_plus_args[num_dynamic_alloced_qregs:], + ) + + # Launch a new interpreter for the body region # A new interpreter's root qreg value needs a new recorder converter = copy(self) converter.qubit_index_recorder = QubitIndexRecorder() - init_qreg = QubitHandler(qreg, converter.qubit_index_recorder) + init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder) converter.init_qreg = init_qreg + # add dynamic qregs to recorder + qreg_map = {} + dyn_qreg_handlers = [] + for dyn_qreg, outer_dynqreg_handler, global_wire_index in zip( + _dynalloced_qregs, dynalloced_qregs, dynalloced_wire_global_indices, strict=True + ): + dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder) + dyn_qreg_handlers.append(dyn_qreg_handler) + + # plxpr global wire index does not change across scopes + # So scope arg dynamic qregs need to have the same root hash as their corresponding + # qreg tracers outside + dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash + + # Each qreg argument of the subscope corresponds to a qreg from the outer scope + qreg_map[args[wire_label_arg_to_tracer_arg_index[global_wire_index]]] = dyn_qreg_handler + + # The new interpreter's recorder needs to be updated to include the qreg args + # of this scope, instead of the outer qregs + for arg in args: + if arg in qreg_map: + converter.qubit_index_recorder[arg] = qreg_map[arg] + retvals = converter(plxpr, *args) - converter.init_qreg.insert_all_dangling_qubits() + + init_qreg.insert_all_dangling_qubits() + + # Return all registers + for dyn_qreg_handler in reversed(dyn_qreg_handlers): + dyn_qreg_handler.insert_all_dangling_qubits() + retvals.insert(0, dyn_qreg_handler.get()) + return converter.init_qreg.get(), *retvals if not transformed: - converted_closed_jaxpr_branch = jax.make_jaxpr(wrapper)(self.init_qreg.get(), *args) + converted_closed_jaxpr_branch = jax.make_jaxpr(wrapper)( + self.init_qreg.get(), *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], *args + ) self.subroutine_cache[plxpr] = converted_closed_jaxpr_branch else: converted_closed_jaxpr_branch = transformed @@ -681,12 +725,13 @@ def wrapper(qreg, *args): # is just pjit_p with a different name. vals_out = quantum_subroutine_p.bind( self.init_qreg.get(), - *args, + *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], + *new_args, jaxpr=converted_closed_jaxpr_branch, - in_shardings=(UNSPECIFIED, *kwargs["in_shardings"]), - out_shardings=(UNSPECIFIED, *kwargs["out_shardings"]), - in_layouts=(None, *kwargs["in_layouts"]), - out_layouts=(None, *kwargs["out_layouts"]), + in_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["in_shardings"]), + out_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["out_shardings"]), + in_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["in_layouts"]), + out_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["out_layouts"]), donated_invars=kwargs["donated_invars"], ctx_mesh=kwargs["ctx_mesh"], name=kwargs["name"], @@ -696,7 +741,9 @@ def wrapper(qreg, *args): ) self.init_qreg.set(vals_out[0]) - vals_out = vals_out[1:] + for i, dyn_qreg in enumerate(dynalloced_qregs): + dyn_qreg.set(vals_out[i + 1]) + vals_out = vals_out[len(dynalloced_qregs) + 1 :] for orig_wire in backup.keys(): self.init_qreg.extract(orig_wire) diff --git a/frontend/catalyst/from_plxpr/qubit_handler.py b/frontend/catalyst/from_plxpr/qubit_handler.py index f9adb55e6..1739b3841 100644 --- a/frontend/catalyst/from_plxpr/qubit_handler.py +++ b/frontend/catalyst/from_plxpr/qubit_handler.py @@ -68,8 +68,6 @@ qubit SSA values on its wires? """ -import textwrap - from catalyst.jax_extras import DynamicJaxprTracer from catalyst.jax_primitives import AbstractQbit, AbstractQreg, qextract_p, qinsert_p from catalyst.utils.exceptions import CompileError @@ -422,21 +420,6 @@ def get_in_qubit_values( if not qubit_index_recorder.contains(w): # First time the global wire index w is encountered # Need to extract from fallback qreg - # TODO: this can now only be from the global qreg, because right now in from_plxpr - # conversion, subscopes (control flow, adjoint, ...) can only take in the global - # qreg as the final scope argument. They cannot take an arbitrary number of qreg - # values yet. - # Supporting multiple registers requires refactoring the from_plxpr conversion's - # implementation. - if is_dynamically_allocated_wire(w): - raise NotImplementedError( - textwrap.dedent( - """ - Dynamically allocated wires in a parent scope cannot be used in a child - scope yet. Please consider dynamical allocation inside the child scope. - """ - ) - ) in_qubits.append(fallback_qreg[fallback_qreg.global_index_to_local_index(w)]) in_qregs.append(fallback_qreg) @@ -446,3 +429,28 @@ def get_in_qubit_values( in_qubits.append(in_qreg[in_qreg.global_index_to_local_index(w)]) return in_qregs, in_qubits + + +def _get_dynamically_allocated_qregs(plxpr_invals, qubit_index_recorder, init_qreg): + """ + Get the potential dynamically allocated register values that are visible to a jaxpr. + + Note that dynamically allocated wires have their qreg tracer's id as the global wire index + so the sub jaxpr takes that id in as a "const", since it is clousure from the target wire + of gates/measurements/... + We need to remove that const, so we also let this util return these global indices. + """ + dynalloced_qregs = [] + dynalloced_wire_global_indices = [] + for inval in plxpr_invals: + if ( + isinstance(inval, int) + and qubit_index_recorder.contains(inval) + and qubit_index_recorder[inval] is not init_qreg + ): + dyn_qreg = qubit_index_recorder[inval] + dyn_qreg.insert_all_dangling_qubits() + dynalloced_qregs.append(dyn_qreg) + dynalloced_wire_global_indices.append(inval) + + return dynalloced_qregs, dynalloced_wire_global_indices diff --git a/frontend/test/lit/test_dynamic_qubit_allocation.py b/frontend/test/lit/test_dynamic_qubit_allocation.py index bb94a1bc6..c6b4a001e 100644 --- a/frontend/test/lit/test_dynamic_qubit_allocation.py +++ b/frontend/test/lit/test_dynamic_qubit_allocation.py @@ -21,10 +21,10 @@ import pennylane as qml from catalyst import qjit -from catalyst.jax_primitives import qalloc_p, qdealloc_qb_p, qextract_p +from catalyst.jax_primitives import qalloc_p, qdealloc_qb_p, qextract_p, subroutine -@qjit +@qjit(target="mlir") def test_single_qubit_dealloc(): """ Unit test for the single qubit dealloc primitive's lowerings. @@ -95,7 +95,7 @@ def test_basic_dynalloc(): print(test_basic_dynalloc.mlir) -@qjit(autograph=True) +@qjit(autograph=True, target="mlir") @qml.qnode(qml.device("lightning.qubit", wires=3)) def test_measure_with_reset(): """ @@ -125,4 +125,175 @@ def test_measure_with_reset(): print(test_measure_with_reset.mlir) +@qjit(autograph=True, target="mlir") +@qml.qnode(qml.device("lightning.qubit", wires=2)) +def test_pass_reg_into_forloop(): + """ + Test using a dynamically allocated resgister from inside a subscope. + """ + + # CHECK: [[global_reg:%.+]] = quantum.alloc( 2) + # CHECK: [[dyn_reg:%.+]] = quantum.alloc( 1) + # CHECK: [[for_out:%.+]]:2 = scf.for %arg0 = {{.+}} to {{.+}} step {{.+}} iter_args + # CHECK-SAME: (%arg1 = [[dyn_reg]], %arg2 = [[global_reg]]) -> (!quantum.reg, !quantum.reg) { + # CHECK: [[x_in:%.+]] = quantum.extract %arg1[ 0] + # CHECK: [[x_out:%.+]] = quantum.custom "PauliX"() [[x_in]] + # CHECK: [[cnot_in:%.+]] = quantum.extract %arg2[ 0] + # CHECK: [[cnot_out:%.+]]:2 = quantum.custom "CNOT"() [[x_out]], [[cnot_in]] + # CHECK: [[global_reg_yield:%.+]] = quantum.insert %arg2[ 0], [[cnot_out]]#1 + # CHECK: [[dyn_reg_yield:%.+]] = quantum.insert %arg1[ 0], [[cnot_out]]#0 + # CHECK: scf.yield [[dyn_reg_yield]], [[global_reg_yield]] : !quantum.reg, !quantum.reg + # CHECK: quantum.dealloc [[for_out]]#0 : !quantum.reg + + with qml.allocate(1) as q: + for _ in range(3): + qml.X(wires=q[0]) + qml.CNOT(wires=[q[0], 0]) + + # CHECK: [[global_bit0:%.+]] = quantum.extract [[for_out]]#1[ 0] + # CHECK: [[global_bit1:%.+]] = quantum.extract [[for_out]]#1[ 1] + # CHECK: [[obs:%.+]] = quantum.compbasis qubits [[global_bit0]], [[global_bit1]] : !quantum.obs + # CHECK: {{.+}} = quantum.probs [[obs]] : tensor<4xf64> + return qml.probs(wires=[0, 1]) + + +print(test_pass_reg_into_forloop.mlir) + + +@qjit(autograph=True, target="mlir") +@qml.qnode(qml.device("lightning.qubit", wires=3)) +def test_pass_multiple_regs_into_forloop(): + """ + Test using multiple dynamically allocated resgisters from inside a subscope. + """ + + # CHECK: [[global_reg:%.+]] = quantum.alloc( 3) + # CHECK: [[q1:%.+]] = quantum.alloc( 1) + # CHECK: [[q2:%.+]] = quantum.alloc( 2) + # CHECK: [[for_out:%.+]]:3 = scf.for %arg0 = {{.+}} to {{.+}} step {{.+}} iter_args + # CHECK-SAME: (%arg1 = [[q1]], %arg2 = [[q2]], %arg3 = [[global_reg]]) + # CHECK-SAME: -> (!quantum.reg, !quantum.reg, !quantum.reg) { + # CHECK: [[q1_0:%.+]] = quantum.extract %arg1[ 0] + # CHECK: [[glob_0:%.+]] = quantum.extract %arg3[ 0] + # CHECK: [[cnot_out0:%.+]]:2 = quantum.custom "CNOT"() [[q1_0]], [[glob_0]] + # CHECK: [[q2_1:%.+]] = quantum.extract %arg2[ 1] + # CHECK: [[glob_1:%.+]] = quantum.extract %arg3[ 1] + # CHECK: [[cnot_out1:%.+]]:2 = quantum.custom "CNOT"() [[q2_1]], [[glob_1]] + # CHECK: [[glob_ins:%.+]] = quantum.insert %arg3[ 0], [[cnot_out0]]#1 + # CHECK: [[glob_yield:%.+]] = quantum.insert [[glob_ins]][ 1], [[cnot_out1]]#1 + # CHECK: [[q1_yield:%.+]] = quantum.insert %arg1[ 0], [[cnot_out0]]#0 + # CHECK: [[q2_yield:%.+]] = quantum.insert %arg2[ 1], [[cnot_out1]]#0 + # CHECK: scf.yield [[q1_yield]], [[q2_yield]], [[glob_yield]] + # CHECK-SAME: : !quantum.reg, !quantum.reg, !quantum.reg + # CHECK: quantum.dealloc [[for_out]]#1 : !quantum.reg + # CHECK: quantum.dealloc [[for_out]]#0 : !quantum.reg + + with qml.allocate(1) as q1: + with qml.allocate(2) as q2: + for _ in range(3): + qml.CNOT(wires=[q1[0], 0]) + qml.CNOT(wires=[q2[1], 1]) + + return qml.probs(wires=[0, 1]) + + +print(test_pass_multiple_regs_into_forloop.mlir) + + +@qjit(autograph=True, target="mlir") +@qml.qnode(qml.device("lightning.qubit", wires=2)) +def test_pass_multiple_regs_into_whileloop(N: int): + """ + Test using multiple dynamically allocated resgisters from inside a while loop. + """ + + # CHECK: [[global_reg:%.+]] = quantum.alloc( 2) + # CHECK: [[q1:%.+]] = quantum.alloc( 1) + # CHECK: [[q2:%.+]] = quantum.alloc( 4) + # CHECK: [[while_out:%.+]]:4 = scf.while (%arg1 = {{%.+}}, %arg2 = [[q1]], %arg3 = [[q2]], + # CHECK-SAME: %arg4 = [[global_reg]]) : (tensor, !quantum.reg, !quantum.reg, !quantum.reg) + # CHECK-SAME: -> (tensor, !quantum.reg, !quantum.reg, !quantum.reg) { + # CHECK: stablehlo.compare LT, %arg1, %arg0 + # CHECK: scf.condition({{%.+}}) %arg1, %arg2, %arg3, %arg4 + # CHECK: } do { + # CHECK: ^bb0(%arg1: tensor, %arg2: !quantum.reg, %arg3: !quantum.reg, %arg4: !quantum.reg + # CHECK: [[q1_0:%.+]] = quantum.extract %arg2[ 0] + # CHECK: [[glob_1:%.+]] = quantum.extract %arg4[ 1] + # CHECK: [[cnot_out0:%.+]]:2 = quantum.custom "CNOT"() [[q1_0]], [[glob_1]] + # CHECK: [[q2_0:%.+]] = quantum.extract %arg3[ 0] + # CHECK: [[cnot_out1:%.+]]:2 = quantum.custom "CNOT"() [[q2_0]], [[cnot_out0]]#1 + # CHECK: [[i:%.+]] = stablehlo.add %arg1, {{%.+}} + # CHECK: [[glob_yield:%.+]] = quantum.insert %arg4[ 1], [[cnot_out1]]#1 + # CHECK: [[q1_yield:%.+]] = quantum.insert %arg2[ 0], [[cnot_out0]]#0 + # CHECK: [[q2_yield:%.+]] = quantum.insert %arg3[ 0], [[cnot_out1]]#0 + # CHECK: scf.yield [[i]], [[q1_yield]], [[q2_yield]], [[glob_yield]] + # CHECK: } + # CHECK: quantum.dealloc [[while_out]]#2 + # CHECK: quantum.dealloc [[while_out]]#1 + + i = 0 + with qml.allocate(1) as q1: + with qml.allocate(4) as q2: + while i < N: + qml.CNOT(wires=[q1[0], 1]) + qml.CNOT(wires=[q2[0], 1]) + i += 1 + + return qml.probs(wires=[0, 1]) + + +print(test_pass_multiple_regs_into_whileloop.mlir) + + +def test_quantum_subroutine(): + """ + Test passing dynamically allocated wires into a quantum subroutine. + """ + + @subroutine + def flip(w1, w2, theta): + qml.X(w1) + qml.X(w2) + qml.ctrl(qml.RX, (w1, w2))(theta, wires=0) + + # CHECK: [[angle:%.+]] = stablehlo.constant dense<1.230000e+00> + # CHECK: [[one:%.+]] = stablehlo.constant dense<1> + # CHECK: [[zero:%.+]] = stablehlo.constant dense<0> + # CHECK: [[global_qreg:%.+]] = quantum.alloc( 1) + # CHECK: [[q1:%.+]] = quantum.alloc( 2) + # CHECK: [[q2:%.+]] = quantum.alloc( 3) + # CHECK: {{%.+}}:3 = call @flip([[global_qreg]], [[q1]], [[q2]], [[zero]], [[one]], [[angle]]) + # CHECK-SAME: (!quantum.reg, !quantum.reg, !quantum.reg, tensor, tensor, tensor) + # CHECK-SAME: -> (!quantum.reg, !quantum.reg, !quantum.reg) + + @qjit(target="mlir") + @qml.qnode(qml.device("lightning.qubit", wires=1)) + def circuit(): + with qml.allocate(2) as q1: + with qml.allocate(3) as q2: + flip(q1[0], q2[1], 1.23) + return qml.probs(wires=[0]) + + # CHECK: func.func private @flip( + # CHECK: [[zero:%.+]] = tensor.extract %arg3[] + # CHECK: [[q1_0:%.+]] = quantum.extract %arg1[[[zero]]] + # CHECK: [[x1_out:%.+]] = quantum.custom "PauliX"() [[q1_0]] + # CHECK: [[one:%.+]] = tensor.extract %arg4[] + # CHECK: [[q2_1:%.+]] = quantum.extract %arg2[[[one]]] + # CHECK: [[x2_out:%.+]] = quantum.custom "PauliX"() [[q2_1]] + # CHECK: [[glob_0:%.+]] = quantum.extract %arg0[ 0] + # CHECK: [[angle:%.+]] = tensor.extract %arg5[] + # CHECK: [[rx_out:%.+]], [[rx_ctrl_out:%.+]]:2 = quantum.custom "RX"([[angle]]) [[glob_0]] + # CHECK-SAME: ctrls([[x1_out]], [[x2_out]]) + # CHECK: [[glob_ret:%.+]] = quantum.insert %arg0[ 0], [[rx_out]] + # CHECK: [[q2_ret:%.+]] = quantum.insert %arg2[{{%.+}}], [[rx_ctrl_out]]#1 + # CHECK: [[q1_ret:%.+]] = quantum.insert %arg1[{{%.+}}], [[rx_ctrl_out]]#0 + # CHECK: return [[glob_ret]], [[q1_ret]], [[q2_ret]] : !quantum.reg, !quantum.reg, !quantum.reg + + print(circuit.mlir) + + +test_quantum_subroutine() + + qml.capture.disable() diff --git a/frontend/test/pytest/test_dynamic_qubit_allocation.py b/frontend/test/pytest/test_dynamic_qubit_allocation.py index c1d62544a..85a3beb45 100644 --- a/frontend/test/pytest/test_dynamic_qubit_allocation.py +++ b/frontend/test/pytest/test_dynamic_qubit_allocation.py @@ -235,6 +235,31 @@ def circuit(c): assert np.allclose(expected, observed) +@pytest.mark.usefixtures("use_capture") +@pytest.mark.parametrize("cond, expected", [(True, [0, 1, 0, 0]), (False, [1, 0, 0, 0])]) +def test_dynamic_wire_alloc_cond_outside(cond, expected, backend): + """ + Test passing dynamically allocated wires into a cond. + """ + + @qjit(autograph=True) + @qml.qnode(qml.device(backend, wires=2)) + def circuit(c): + with qml.allocate(1) as q1: + with qml.allocate(1) as q2: + qml.X(q1[0]) + if c: + qml.CNOT(wires=[q1[0], 1]) # |01> + else: + qml.CNOT(wires=[q2[0], 1]) # |00> + + return qml.probs(wires=[0, 1]) + + observed = circuit(cond) + + assert np.allclose(expected, observed) + + @pytest.mark.usefixtures("use_capture") @pytest.mark.parametrize( "num_iter, expected", [(3, [0, 0, 1, 0, 0, 0, 0, 0]), (4, [1, 0, 0, 0, 0, 0, 0, 0])] @@ -260,6 +285,51 @@ def circuit(N): assert np.allclose(expected, observed) +@pytest.mark.usefixtures("use_capture") +def test_dynamic_wire_alloc_forloop_outside(backend): + """ + Test passing dynamically allocated wires into a for loop. + """ + + @qjit(autograph=True) + @qml.qnode(qml.device(backend, wires=1)) + def circuit(): + with qml.allocate(1) as q: + qml.X(wires=q[0]) + for _ in range(3): + qml.CNOT(wires=[q[0], 0]) + + return qml.probs(wires=[0]) + + observed = circuit() + expected = [0, 1] + + assert np.allclose(expected, observed) + + +@pytest.mark.usefixtures("use_capture") +def test_dynamic_wire_alloc_forloop_outside_multiple_regs(backend): + """ + Test using multiple dynamically allocated registers from inside for loop. + """ + + @qjit(autograph=True) + @qml.qnode(qml.device(backend, wires=1)) + def circuit(): + with qml.allocate(1) as q1: + with qml.allocate(1) as q2: + for _ in range(3): + qml.CNOT(wires=[q1[0], 0]) + qml.CNOT(wires=[q2[0], 0]) + + return qml.probs(wires=[0]) + + observed = circuit() + expected = [1, 0] + + assert np.allclose(expected, observed) + + @pytest.mark.usefixtures("use_capture") @pytest.mark.parametrize( "num_iter, expected", [(3, [0, 0, 1, 0, 0, 0, 0, 0]), (4, [1, 0, 0, 0, 0, 0, 0, 0])] @@ -287,6 +357,83 @@ def circuit(N): assert np.allclose(expected, observed) +@pytest.mark.usefixtures("use_capture") +@pytest.mark.parametrize("num_iter, expected", [(3, [0, 1, 0, 0]), (4, [1, 0, 0, 0])]) +def test_dynamic_wire_alloc_whileloop_outside(num_iter, expected, backend): + """ + Test passing dynamically allocated wires into a while loop. + """ + + @qjit(autograph=True) + @qml.qnode(qml.device(backend, wires=2)) + def circuit(N): + i = 0 + with qml.allocate(1) as q1: + with qml.allocate(1) as q2: + qml.X(q1[0]) + while i < N: + qml.CNOT(wires=[q1[0], 1]) + qml.CNOT(wires=[q2[0], 1]) + i += 1 + + return qml.probs(wires=[0, 1]) + + observed = circuit(num_iter) + + assert np.allclose(expected, observed) + + +@pytest.mark.usefixtures("use_capture") +@pytest.mark.parametrize("flip_again, expected", [(True, [1, 0]), (False, [0, 1])]) +def test_subroutine(flip_again, expected, backend): + """ + Test passing dynamically allocated wires into a subroutine. + """ + + @subroutine + def flip(w): + qml.X(w) + qml.CNOT(wires=[w, 0]) + + @qjit + @qml.qnode(qml.device(backend, wires=1)) + def circuit(): + with qml.allocate(1) as q1: + with qml.allocate(1) as q2: + flip(q1[0]) + if flip_again: + flip(q2[0]) + return qml.probs(wires=[0]) + + observed = circuit() + assert np.allclose(expected, observed) + + +@pytest.mark.usefixtures("use_capture") +def test_subroutine_multiple_args(backend): + """ + Test passing dynamically allocated wires into a subroutine with multiple arguments. + """ + + @subroutine + def flip(w1, w2, theta): + qml.X(w1) + qml.X(w2) + qml.ctrl(qml.RX, (w1, w2))(theta, wires=0) + + @qjit + @qml.qnode(qml.device(backend, wires=1)) + def circuit(): + with qml.allocate(1) as q1: + with qml.allocate(2) as q2: + flip(q1[0], q2[1], jnp.pi) + return qml.probs(wires=[0]) + + observed = circuit() + expected = [0, 1] + assert np.allclose(expected, observed) + + def test_no_capture(backend): """ Test error message when used without capture. @@ -371,62 +518,5 @@ def circuit(): return qml.probs(q) -@pytest.mark.usefixtures("use_capture") -def test_unsupported_cross_scope_registers(backend): - """ - Scope jaxprs in Catalyst cannot take multiple registers yet. - Test that an error is raised when a dynamically allocated register in an outside scope - is being used from an inside scope. - """ - - with pytest.raises( - NotImplementedError, - match=textwrap.dedent( - """ - Dynamically allocated wires in a parent scope cannot be used in a child - scope yet. Please consider dynamical allocation inside the child scope. - """ - ), - ): - - @qjit(autograph=True) - @qml.qnode(qml.device(backend, wires=3)) - def circuit(): - wires = qml.allocate(3) - - for _ in range(3): - qml.X(wires=wires[0]) - - return qml.probs(wires=[0, 1, 2]) - - -@pytest.mark.usefixtures("use_capture") -def test_unsupported_subroutine(backend): - """ - Test that an error is raised when a dynamically allocated wire is passed into a subroutine. - """ - - with pytest.raises( - NotImplementedError, - match=textwrap.dedent( - """ - Dynamically allocated wires in a parent scope cannot be used in a child - scope yet. Please consider dynamical allocation inside the child scope. - """ - ), - ): - - @subroutine - def sub(_): - pass - - @qjit - @qml.qnode(qml.device(backend, wires=2)) - def circuit(): - with qml.allocate(1) as q: - sub(q[0]) - return qml.probs(wires=[0, 1]) - - if __name__ == "__main__": pytest.main(["-x", __file__])