Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions frontend/catalyst/from_plxpr/from_plxpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
import pennylane as qml
from jax.extend.core import ClosedJaxpr, Jaxpr
from jax.extend.linear_util import wrap_init
from pennylane.capture import PlxprInterpreter, qnode_prim
from pennylane.capture import PlxprInterpreter, qnode_prim, enable, disable, enabled
from pennylane.capture.expand_transforms import ExpandTransformsInterpreter
from pennylane.capture.primitives import jacobian_prim as pl_jac_prim
from pennylane.capture.primitives import transform_prim
Expand Down Expand Up @@ -414,9 +414,14 @@ def trace_from_pennylane(
# https://github.com/jax-ml/jax/blob/636691bba40b936b8b64a4792c1d2158296e9dd4/jax/_src/linear_util.py#L231
# Therefore we need to coordinate them manually
fn.static_argnums = static_argnums

plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)
jaxpr = from_plxpr(plxpr)(*plxpr.in_avals)
capture_on = enabled()
enable()
try:
plxpr, out_type, out_treedef = make_jaxpr2(fn, **make_jaxpr_kwargs)(*args, **kwargs)
jaxpr = from_plxpr(plxpr)(*plxpr.in_avals)
finally:
if not capture_on:
disable()

return jaxpr, out_type, out_treedef, sig

Expand Down
5 changes: 3 additions & 2 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def qjit(
*,
autograph=False,
autograph_include=(),
experimental_capture=False,
async_qnodes=False,
target="binary",
keep_intermediate=False,
Expand Down Expand Up @@ -698,7 +699,7 @@ def jit_compile(self, args, **kwargs):
def pre_compilation(self):
"""Perform pre-processing tasks on the Python function, such as AST transformations."""
if self.compile_options.autograph:
if qml.capture.enabled():
if self.compile_options.experimental_capture:
if self.compile_options.autograph_include:
raise NotImplementedError(
"capture autograph does not yet support autograph_include."
Expand Down Expand Up @@ -731,7 +732,7 @@ def capture(self, args, **kwargs):

dbg = debug_info("qjit_capture", self.user_function, args, kwargs)

if qml.capture.enabled():
if self.compile_options.experimental_capture:
with Patcher(
(
jax._src.interpreters.partial_eval, # pylint: disable=protected-access
Expand Down
4 changes: 4 additions & 0 deletions frontend/catalyst/pipelines.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,12 @@ class CompileOptions:
Default is ``None``.
pass_plugins (Optional[Iterable[Path]]): List of paths to pass plugins.
dialect_plugins (Optional[Iterable[Path]]): List of paths to dialect plugins.
experimental_capture (bool): If set to ``True``,
use PennyLane's experimental program capture capabilities
to capture the function for compilation.
"""

experimental_capture : bool = False
verbose: Optional[bool] = False
logfile: Optional[TextIOWrapper] = sys.stderr
target: Optional[str] = "binary"
Expand Down
3 changes: 3 additions & 0 deletions frontend/test/pytest.ini
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
[pytest]
markers =
capture_only: marks tests for capture only
old_frontend: tests for capture off
xfail_strict=true
6 changes: 4 additions & 2 deletions frontend/test/pytest/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,15 +62,17 @@
@pytest.fixture(scope="function")
def use_capture_dgraph():
"""Enable capture and graph-decomposition before and disable them both after the test."""
qml.capture.enable()
qml.decomposition.enable_graph()
try:
yield
finally:
qml.decomposition.disable_graph()
qml.capture.disable()


@pytest.fixture(params=[True, False], scope="function")
def experimental_capture(request):

Check notice on line 73 in frontend/test/pytest/conftest.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/test/pytest/conftest.py#L73

Missing function or method docstring (missing-function-docstring)
yield request.param

@pytest.fixture(params=["capture", "no_capture"], scope="function")
def use_both_frontend(request):
"""Runs the test once with capture enabled and once with it disabled."""
Expand Down
Loading
Loading