Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 28 additions & 15 deletions pytensor/ipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from IPython.display import display

from pytensor.graph import FunctionGraph, Variable, rewrite_graph
from pytensor.graph.features import FullHistory
from pytensor.graph.features import AlreadyThere, FullHistory

Check warning on line 7 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L7

Added line #L7 was not covered by tests


class CodeBlockWidget(anywidget.AnyWidget):
Expand Down Expand Up @@ -45,29 +45,41 @@

class InteractiveRewrite:
"""
A class that wraps a graph history object with interactive widgets
to navigate through history and display the graph at each step.

Includes an option to display the reason for the last change.
Visualize a graph history through a series of rewrites.
"""

def __init__(self, fg, display_reason=True):
def __init__(

Check warning on line 51 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L51

Added line #L51 was not covered by tests
self,
fg,
display_reason=True,
rewrite_options: dict | None = None,
dprint_options: dict | None = None,
):
"""
Initialize with a history object that has a goto method
and tracks a FunctionGraph.

Parameters:
-----------
fg : FunctionGraph (or Variables)
The function graph to track
display_reason : bool, optional
Whether to display the reason for each rewrite
rewrite_options : dict, optional
Options for rewriting the graph. Defaults to {'include': ('fast_run',), 'exclude': ('inplace',)}
print_options : dict, optional
Print options passed to `debugprint` used to generate the text representation of the graph.
Useful options are {'print_shape': True, 'print_op_info': True}
"""
self.dprint_options = dprint_options or {}
self.rewrite_options = rewrite_options or dict(

Check warning on line 72 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L71-L72

Added lines #L71 - L72 were not covered by tests
include=("fast_run",), exclude=("inplace",)
)
self.history = FullHistory(callback=self._history_callback)
if not isinstance(fg, FunctionGraph):
outs = [fg] if isinstance(fg, Variable) else fg
fg = FunctionGraph(outputs=outs)
fg.attach_feature(self.history)
try:
fg.attach_feature(self.history)
except AlreadyThere:
self.history.end()

Check warning on line 82 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L79-L82

Added lines #L79 - L82 were not covered by tests

self.updating_from_callback = False # Flag to prevent recursion
self.code_widget = CodeBlockWidget(content="")
Expand Down Expand Up @@ -163,7 +175,7 @@
reason = ""
else:
reason = self.history.fw[self.history.pointer].reason
reason = getattr(reason, "name", str(reason))
reason = getattr(reason, "name", None) or str(reason)

Check warning on line 178 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L178

Added line #L178 was not covered by tests

self.reason_label.value = f"""
<div style='padding: 5px; margin-bottom: 10px; background-color: #e6f7ff; border-left: 4px solid #1890ff;'>
Expand All @@ -172,7 +184,9 @@
"""

# Update the graph display
self.code_widget.content = self.history.fg.dprint(file="str")
self.code_widget.content = self.history.fg.dprint(

Check warning on line 187 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L187

Added line #L187 was not covered by tests
file="str", **self.dprint_options
)

# Update slider range if history length has changed
history_len = len(self.history.fw) + 1
Expand All @@ -189,14 +203,13 @@
f"History: {self.history.pointer + 1}/{history_len - 1}"
)

def rewrite(self, *args, include=("fast_run",), exclude=("inplace",), **kwargs):
def rewrite(self, *args, **kwargs):

Check warning on line 206 in pytensor/ipython.py

View check run for this annotation

Codecov / codecov/patch

pytensor/ipython.py#L206

Added line #L206 was not covered by tests
"""Apply rewrites to the current graph"""
rewrite_graph(
self.history.fg,
*args,
include=include,
exclude=exclude,
**kwargs,
**self.rewrite_options,
clone=False,
)
self._update_display()