1818
1919
2020import warnings
21+ from copy import copy
2122from functools import partial
2223from typing import Callable
2324
@@ -137,6 +138,14 @@ def f(x):
137138class WorkflowInterpreter (PlxprInterpreter ):
138139 """An interpreter that converts a qnode primitive from a plxpr variant to a catalyst jaxpr variant."""
139140
141+ def __copy__ (self ):
142+ new_version = WorkflowInterpreter ()
143+ new_version ._pass_pipeline = copy (self ._pass_pipeline )
144+ new_version .init_qreg = self .init_qreg
145+ new_version .requires_decompose_lowering = self .requires_decompose_lowering
146+ new_version .decompose_tkwargs = copy (self .decompose_tkwargs )
147+ return new_version
148+
140149 def __init__ (self ):
141150 self ._pass_pipeline = []
142151 self .init_qreg = None
@@ -284,12 +293,13 @@ def handle_transform(
284293 "Multiple decomposition transforms are not yet supported."
285294 )
286295
296+ next_eval = copy (self )
287297 # Update the decompose_gateset to be used by the quantum kernel primitive
288298 # TODO: we originally wanted to treat decompose_gateset as a queue of
289299 # gatesets to be used by the decompose-lowering pass at MLIR
290300 # but this requires a C++ implementation of the graph-based decomposition
291301 # which doesn't exist yet.
292- self .decompose_tkwargs = tkwargs
302+ next_eval .decompose_tkwargs = tkwargs
293303
294304 # Note. We don't perform the compiler-specific decomposition here
295305 # to be able to support multiple decomposition transforms
@@ -300,7 +310,7 @@ def handle_transform(
300310 # in the qnode handler.
301311
302312 # Add the decompose-lowering pass to the start of the pipeline
303- self ._pass_pipeline .insert (0 , Pass ("decompose-lowering" ))
313+ next_eval ._pass_pipeline .insert (0 , Pass ("decompose-lowering" ))
304314
305315 # We still need to construct and solve the graph based on
306316 # the current jaxpr based on the current gateset
@@ -313,7 +323,7 @@ def handle_transform(
313323
314324 # final_jaxpr = jax.make_jaxpr(gds_wrapper)(*args)
315325 # return self.eval(final_jaxpr.jaxpr, consts, *non_const_args)
316- return self .eval (inner_jaxpr , consts , * non_const_args )
326+ return next_eval .eval (inner_jaxpr , consts , * non_const_args )
317327
318328 if catalyst_pass_name is None :
319329 # Use PL's ExpandTransformsInterpreter to expand this and any embedded
@@ -333,11 +343,12 @@ def wrapper(*args):
333343 final_jaxpr .jaxpr , final_jaxpr .consts , targs , tkwargs , * non_const_args
334344 )
335345
336- return self .eval (final_jaxpr .jaxpr , final_jaxpr .consts , * non_const_args )
346+ return copy ( self ) .eval (final_jaxpr .jaxpr , final_jaxpr .consts , * non_const_args )
337347
338348 # Apply the corresponding Catalyst pass counterpart
339- self ._pass_pipeline .insert (0 , Pass (catalyst_pass_name , * targs , ** tkwargs ))
340- return self .eval (inner_jaxpr , consts , * non_const_args )
349+ next_eval = copy (self )
350+ next_eval ._pass_pipeline .insert (0 , Pass (catalyst_pass_name , * targs , ** tkwargs ))
351+ return next_eval .eval (inner_jaxpr , consts , * non_const_args )
341352
342353
343354# This is our registration factory for PL transforms. The loop below iterates
0 commit comments