diff --git a/doc/releases/changelog-dev.md b/doc/releases/changelog-dev.md index 8731ff9a1b..063de07b42 100644 --- a/doc/releases/changelog-dev.md +++ b/doc/releases/changelog-dev.md @@ -7,6 +7,7 @@ [(#2229)](https://github.com/PennyLaneAI/catalyst/pull/2229) [(#2214)](https://github.com/PennyLaneAI/catalyst/pull/2214) [(#2246)](https://github.com/PennyLaneAI/catalyst/pull/2246) + [(#2243)](https://github.com/PennyLaneAI/catalyst/pull/2243) * Added ``catalyst.switch``, a qjit compatible, index-switch style control flow decorator. [(#2171)](https://github.com/PennyLaneAI/catalyst/pull/2171) diff --git a/frontend/catalyst/python_interface/inspection/draw.py b/frontend/catalyst/python_interface/inspection/draw.py index 5bee87b522..4dc251a280 100644 --- a/frontend/catalyst/python_interface/inspection/draw.py +++ b/frontend/catalyst/python_interface/inspection/draw.py @@ -15,15 +15,20 @@ from __future__ import annotations +import io import warnings from functools import wraps from typing import TYPE_CHECKING +import matplotlib.image as mpimg +import matplotlib.pyplot as plt from pennylane.tape import QuantumScript from catalyst.python_interface.compiler import Compiler from .collector import QMLCollector +from .construct_circuit_dag import ConstructCircuitDAG +from .pydot_dag_builder import PyDotDAGBuilder from .xdsl_conversion import get_mlir_module if TYPE_CHECKING: @@ -90,3 +95,53 @@ def wrapper(*args, **kwargs): return cache.get(level, cache[max(cache.keys())])[0] return wrapper + + +def draw_graph(qnode: QNode, *, level: None | int = None) -> Callable: + """ + ??? + """ + cache: dict[int, tuple[str, str]] = _cache_store.setdefault(qnode, {}) + + def _draw_callback(previous_pass, module, next_pass, pass_level=0): + """Callback function for circuit drawing.""" + + pass_instance = previous_pass if previous_pass else next_pass + # Process module to build DAG + utility = ConstructCircuitDAG(PyDotDAGBuilder()) + utility.construct(module) + # Store DAG in cache + image_bytes = utility.dag_builder.graph.create_png(prog="dot") + pass_name = pass_instance.name if hasattr(pass_instance, "name") else pass_instance + cache[pass_level] = ( + image_bytes, + pass_name if pass_level else "No transforms", + ) + + @wraps(qnode) + def wrapper(*args, **kwargs): + mlir_module = _get_mlir_module(qnode, args, kwargs) + Compiler.run(mlir_module, callback=_draw_callback) + + if not cache: + return None + + # Retrieve Data (Fall back to highest level if 'level' is not found) + max_level = max(cache.keys()) + image_bytes, pass_name = cache.get(level, cache[max_level]) + + # Render image bytes to matplotlib + sio = io.BytesIO() + sio.write(image_bytes) + sio.seek(0) + + img = mpimg.imread(sio) + + fig, ax = plt.subplots() + ax.imshow(img) + ax.set_axis_off() + ax.set_title(f"Level {level if level is not None else max_level}: {pass_name}", fontsize=10) + + return fig, ax + + return wrapper