Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 110 additions & 28 deletions frontend/catalyst/from_plxpr/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,25 +26,70 @@
from pennylane.capture.primitives import while_loop_prim as plxpr_while_loop_prim

from catalyst.from_plxpr.from_plxpr import PLxPRToQuantumJaxprInterpreter, WorkflowInterpreter
from catalyst.from_plxpr.qubit_handler import QubitHandler, QubitIndexRecorder
from catalyst.from_plxpr.qubit_handler import (
QubitHandler,
QubitIndexRecorder,
_get_dynamically_allocated_qregs,
)
from catalyst.jax_extras import jaxpr_pad_consts
from catalyst.jax_primitives import cond_p, for_p, while_p


def _calling_convention(interpreter, closed_jaxpr, *args_plus_qreg):
# The last arg is the scope argument for the body jaxpr
*args, qreg = args_plus_qreg
def _calling_convention(
interpreter, closed_jaxpr, *args_plus_qregs, outer_dynqreg_handlers=(), return_qreg=True
):
# Arg structure (all args are tracers, since this function is to be `make_jaxpr`'d):
# Regular args, then dynamically allocated qregs, then global qreg
# TODO: merge dynamically allocaed qregs into regular args?
# But this is tricky, since qreg arguments need all the SSA value semantics conversion infra
# and are different from the regular plain arguments.
*args_plus_dynqregs, global_qreg = args_plus_qregs
num_dynamic_alloced_qregs = len(outer_dynqreg_handlers)
args, dynalloced_qregs = (
args_plus_dynqregs[: len(args_plus_dynqregs) - num_dynamic_alloced_qregs],
args_plus_dynqregs[len(args_plus_dynqregs) - num_dynamic_alloced_qregs :],
)

# Launch a new interpreter for the body region
# A new interpreter's root qreg value needs a new recorder
converter = copy(interpreter)
converter.qubit_index_recorder = QubitIndexRecorder()
init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder)
converter.init_qreg = init_qreg

# pylint: disable-next=cell-var-from-loop
# add dynamic qregs to recorder
qreg_map = {}
dyn_qreg_handlers = []
for dyn_qreg, outer_dynqreg_handler in zip(
dynalloced_qregs, outer_dynqreg_handlers, strict=True
):
dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder)
dyn_qreg_handlers.append(dyn_qreg_handler)

# plxpr global wire index does not change across scopes
# So scope arg dynamic qregs need to have the same root hash as their corresponding
# qreg tracers outside
dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash

# Each qreg argument of the subscope corresponds to a qreg from the outer scope
qreg_map[outer_dynqreg_handler] = dyn_qreg_handler

# The new interpreter's recorder needs to be updated to include the qreg args
# of this scope, instead of the outer qregs
if qreg_map:
for k, outer_dynqreg_handler in interpreter.qubit_index_recorder.map.items():
converter.qubit_index_recorder[k] = qreg_map[outer_dynqreg_handler]

retvals = converter(closed_jaxpr, *args)
if not return_qreg:
return retvals

init_qreg.insert_all_dangling_qubits()

# Return all registers
for dyn_qreg_handler in dyn_qreg_handlers:
dyn_qreg_handler.insert_all_dangling_qubits()
retvals.append(dyn_qreg_handler.get())
return *retvals, converter.init_qreg.get()


Expand Down Expand Up @@ -90,7 +135,18 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
"""Handle the conversion from plxpr to Catalyst jaxpr for the cond primitive"""
args = plxpr_invals[args_slice]
self.init_qreg.insert_all_dangling_qubits()
args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args

dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)

# Add the qregs to the args
args_plus_qreg = [
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
]

converted_jaxpr_branches = []
all_consts = []

Expand All @@ -103,7 +159,9 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
converted_jaxpr_branch = None
closed_jaxpr = ClosedJaxpr(plxpr_branch, branch_consts)

f = partial(_calling_convention, self, closed_jaxpr)
f = partial(
_calling_convention, self, closed_jaxpr, outer_dynqreg_handlers=dynalloced_qregs
)
converted_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg)

all_consts += converted_jaxpr_branch.consts
Expand All @@ -112,6 +170,8 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
predicate = [_to_bool_if_not(p) for p in plxpr_invals[: len(jaxpr_branches) - 1]]

# Build Catalyst compatible input values
# strip global wire indices of dynamic wires
all_consts = tuple(const for const in all_consts if const not in dynalloced_wire_global_indices)
cond_invals = [*predicate, *all_consts, *args_plus_qreg]

# Perform the binding
Expand All @@ -121,9 +181,12 @@ def handle_cond(self, *plxpr_invals, jaxpr_branches, consts_slices, args_slice):
nimplicit_outputs=None,
)

# We assume the last output value is the returned qreg.
# Output structure:
# First a list of dynamically allocated qregs, then the global qreg
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())
for dyn_qreg in reversed(dynalloced_qregs):
dyn_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals
Expand Down Expand Up @@ -193,25 +256,38 @@ def handle_for_loop(

# Add the iteration start and the qreg to the args
self.init_qreg.insert_all_dangling_qubits()

dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)

start_plus_args_plus_qreg = [
start,
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
]

consts = plxpr_invals[consts_slice]

jaxpr = ClosedJaxpr(jaxpr_body_fn, consts)

f = partial(_calling_convention, self, jaxpr)
f = partial(
_calling_convention,
self,
jaxpr,
outer_dynqreg_handlers=dynalloced_qregs,
)
converted_jaxpr_branch = jax.make_jaxpr(f)(*start_plus_args_plus_qreg)

converted_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_jaxpr_branch.jaxpr), ()
)

# Build Catalyst compatible input values
# strip global wire indices of dynamic wires
new_consts = converted_jaxpr_branch.consts
new_consts = tuple(const for const in new_consts if const not in dynalloced_wire_global_indices)
for_loop_invals = [*new_consts, start, stop, step, *start_plus_args_plus_qreg]

# Config additional for loop settings
Expand All @@ -227,10 +303,14 @@ def handle_for_loop(
preserve_dimensions=True,
)

# We assume the last output value is the returned qreg.
# Output structure:
# First a list of dynamically allocated qregs, then the global qreg
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())

for dyn_qreg in reversed(dynalloced_qregs):
dyn_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals

Expand Down Expand Up @@ -289,14 +369,21 @@ def handle_while_loop(
):
"""Handle the conversion from plxpr to Catalyst jaxpr for the while loop primitive"""
self.init_qreg.insert_all_dangling_qubits()
dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
plxpr_invals, self.qubit_index_recorder, self.init_qreg
)
consts_body = plxpr_invals[body_slice]
consts_cond = plxpr_invals[cond_slice]
args = plxpr_invals[args_slice]
args_plus_qreg = [*args, self.init_qreg.get()] # Add the qreg to the args
args_plus_qreg = [
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
self.init_qreg.get(),
] # Add the qreg to the args

jaxpr = ClosedJaxpr(jaxpr_body_fn, consts_body)

f = partial(_calling_convention, self, jaxpr)
f = partial(_calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs)
converted_body_jaxpr_branch = jax.make_jaxpr(f)(*args_plus_qreg).jaxpr

converted_body_closed_jaxpr_branch = ClosedJaxpr(
Expand All @@ -307,30 +394,22 @@ def handle_while_loop(
# We need to be able to handle arbitrary plxpr here.
# But we want to be able to create a state where:
# * We do not pass the quantum register as an argument.

# So let's just remove the quantum register here at the end

jaxpr = ClosedJaxpr(jaxpr_cond_fn, consts_cond)

def remove_qreg(*args_plus_qreg):
# The last arg is the scope argument for the body jaxpr
*args, qreg = args_plus_qreg

# Launch a new interpreter for the body region
# A new interpreter's root qreg value needs a new recorder
converter = copy(self)
converter.qubit_index_recorder = QubitIndexRecorder()
init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
converter.init_qreg = init_qreg

return converter(jaxpr, *args)
f_remove_qreg = partial(
_calling_convention, self, jaxpr, outer_dynqreg_handlers=dynalloced_qregs, return_qreg=False
)

converted_cond_jaxpr_branch = jax.make_jaxpr(remove_qreg)(*args_plus_qreg).jaxpr
converted_cond_jaxpr_branch = jax.make_jaxpr(f_remove_qreg)(*args_plus_qreg).jaxpr
converted_cond_closed_jaxpr_branch = ClosedJaxpr(
convert_constvars_jaxpr(converted_cond_jaxpr_branch), ()
)

# Build Catalyst compatible input values
consts_body = tuple(
const for const in consts_body if const not in dynalloced_wire_global_indices
)
while_loop_invals = [*consts_cond, *consts_body, *args_plus_qreg]

# Perform the binding
Expand All @@ -348,5 +427,8 @@ def remove_qreg(*args_plus_qreg):
# Update the current qreg and remove it from the output values.
self.init_qreg.set(outvals.pop())

for dyn_qreg in reversed(dynalloced_qregs):
dyn_qreg.set(outvals.pop())

# Return only the output values that match the plxpr output values
return outvals
89 changes: 68 additions & 21 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from catalyst.from_plxpr.qubit_handler import (
QubitHandler,
QubitIndexRecorder,
_get_dynamically_allocated_qregs,
get_in_qubit_values,
is_dynamically_allocated_wire,
)
Expand Down Expand Up @@ -642,37 +643,80 @@ def handle_subroutine(self, *args, **kwargs):
Transform the subroutine from PLxPR into JAXPR with quantum primitives.
"""

if any(is_dynamically_allocated_wire(arg) for arg in args):
raise NotImplementedError(
textwrap.dedent(
"""
Dynamically allocated wires in a parent scope cannot be used in a child
scope yet. Please consider dynamical allocation inside the child scope.
"""
)
)

backup = dict(self.init_qreg)
self.init_qreg.insert_all_dangling_qubits()

# Make sure the quantum register is updated
plxpr = kwargs["jaxpr"]
transformed = self.subroutine_cache.get(plxpr)

def wrapper(qreg, *args):
# Launch a new interpreter for the new subroutine region
dynalloced_qregs, dynalloced_wire_global_indices = _get_dynamically_allocated_qregs(
args, self.qubit_index_recorder, self.init_qreg
)

# Convert global wire indices into local indices
new_args = ()
wire_label_arg_to_tracer_arg_index = {}
for i, arg in enumerate(args):
if arg in dynalloced_wire_global_indices:
wire_label_arg_to_tracer_arg_index[arg] = i
new_args += (self.qubit_index_recorder[arg].global_index_to_local_index(arg),)
else:
new_args += (arg,)

def wrapper(*qregs_plus_args):
global_qreg, *dynqregs_plus_args = qregs_plus_args
num_dynamic_alloced_qregs = len(dynalloced_qregs)
_dynalloced_qregs, args = (
dynqregs_plus_args[:num_dynamic_alloced_qregs],
dynqregs_plus_args[num_dynamic_alloced_qregs:],
)

# Launch a new interpreter for the body region
# A new interpreter's root qreg value needs a new recorder
converter = copy(self)
converter.qubit_index_recorder = QubitIndexRecorder()
init_qreg = QubitHandler(qreg, converter.qubit_index_recorder)
init_qreg = QubitHandler(global_qreg, converter.qubit_index_recorder)
converter.init_qreg = init_qreg

# add dynamic qregs to recorder
qreg_map = {}
dyn_qreg_handlers = []
for dyn_qreg, outer_dynqreg_handler, global_wire_index in zip(
_dynalloced_qregs, dynalloced_qregs, dynalloced_wire_global_indices, strict=True
):
dyn_qreg_handler = QubitHandler(dyn_qreg, converter.qubit_index_recorder)
dyn_qreg_handlers.append(dyn_qreg_handler)

# plxpr global wire index does not change across scopes
# So scope arg dynamic qregs need to have the same root hash as their corresponding
# qreg tracers outside
dyn_qreg_handler.root_hash = outer_dynqreg_handler.root_hash

# Each qreg argument of the subscope corresponds to a qreg from the outer scope
qreg_map[args[wire_label_arg_to_tracer_arg_index[global_wire_index]]] = dyn_qreg_handler

# The new interpreter's recorder needs to be updated to include the qreg args
# of this scope, instead of the outer qregs
for arg in args:
if arg in qreg_map:
converter.qubit_index_recorder[arg] = qreg_map[arg]

retvals = converter(plxpr, *args)
converter.init_qreg.insert_all_dangling_qubits()

init_qreg.insert_all_dangling_qubits()

# Return all registers
for dyn_qreg_handler in reversed(dyn_qreg_handlers):
dyn_qreg_handler.insert_all_dangling_qubits()
retvals.insert(0, dyn_qreg_handler.get())

return converter.init_qreg.get(), *retvals

if not transformed:
converted_closed_jaxpr_branch = jax.make_jaxpr(wrapper)(self.init_qreg.get(), *args)
converted_closed_jaxpr_branch = jax.make_jaxpr(wrapper)(
self.init_qreg.get(), *[dyn_qreg.get() for dyn_qreg in dynalloced_qregs], *args
)
self.subroutine_cache[plxpr] = converted_closed_jaxpr_branch
else:
converted_closed_jaxpr_branch = transformed
Expand All @@ -681,12 +725,13 @@ def wrapper(qreg, *args):
# is just pjit_p with a different name.
vals_out = quantum_subroutine_p.bind(
self.init_qreg.get(),
*args,
*[dyn_qreg.get() for dyn_qreg in dynalloced_qregs],
*new_args,
jaxpr=converted_closed_jaxpr_branch,
in_shardings=(UNSPECIFIED, *kwargs["in_shardings"]),
out_shardings=(UNSPECIFIED, *kwargs["out_shardings"]),
in_layouts=(None, *kwargs["in_layouts"]),
out_layouts=(None, *kwargs["out_layouts"]),
in_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["in_shardings"]),
out_shardings=(*(UNSPECIFIED,) * (len(dynalloced_qregs) + 1), *kwargs["out_shardings"]),
in_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["in_layouts"]),
out_layouts=(*(None,) * (len(dynalloced_qregs) + 1), *kwargs["out_layouts"]),
donated_invars=kwargs["donated_invars"],
ctx_mesh=kwargs["ctx_mesh"],
name=kwargs["name"],
Expand All @@ -696,7 +741,9 @@ def wrapper(qreg, *args):
)

self.init_qreg.set(vals_out[0])
vals_out = vals_out[1:]
for i, dyn_qreg in enumerate(dynalloced_qregs):
dyn_qreg.set(vals_out[i + 1])
vals_out = vals_out[len(dynalloced_qregs) + 1 :]

for orig_wire in backup.keys():
self.init_qreg.extract(orig_wire)
Expand Down
Loading