diff --git a/pytensor/ipython.py b/pytensor/ipython.py index 33adf5792d..9fd50d1443 100644 --- a/pytensor/ipython.py +++ b/pytensor/ipython.py @@ -4,7 +4,7 @@ from IPython.display import display from pytensor.graph import FunctionGraph, Variable, rewrite_graph -from pytensor.graph.features import FullHistory +from pytensor.graph.features import AlreadyThere, FullHistory class CodeBlockWidget(anywidget.AnyWidget): @@ -45,29 +45,41 @@ class CodeBlockWidget(anywidget.AnyWidget): class InteractiveRewrite: """ - A class that wraps a graph history object with interactive widgets - to navigate through history and display the graph at each step. - - Includes an option to display the reason for the last change. + Visualize a graph history through a series of rewrites. """ - def __init__(self, fg, display_reason=True): + def __init__( + self, + fg, + display_reason=True, + rewrite_options: dict | None = None, + dprint_options: dict | None = None, + ): """ - Initialize with a history object that has a goto method - and tracks a FunctionGraph. - Parameters: ----------- fg : FunctionGraph (or Variables) The function graph to track display_reason : bool, optional Whether to display the reason for each rewrite + rewrite_options : dict, optional + Options for rewriting the graph. Defaults to {'include': ('fast_run',), 'exclude': ('inplace',)} + print_options : dict, optional + Print options passed to `debugprint` used to generate the text representation of the graph. + Useful options are {'print_shape': True, 'print_op_info': True} """ + self.dprint_options = dprint_options or {} + self.rewrite_options = rewrite_options or dict( + include=("fast_run",), exclude=("inplace",) + ) self.history = FullHistory(callback=self._history_callback) if not isinstance(fg, FunctionGraph): outs = [fg] if isinstance(fg, Variable) else fg fg = FunctionGraph(outputs=outs) - fg.attach_feature(self.history) + try: + fg.attach_feature(self.history) + except AlreadyThere: + self.history.end() self.updating_from_callback = False # Flag to prevent recursion self.code_widget = CodeBlockWidget(content="") @@ -163,7 +175,7 @@ def _update_display(self): reason = "" else: reason = self.history.fw[self.history.pointer].reason - reason = getattr(reason, "name", str(reason)) + reason = getattr(reason, "name", None) or str(reason) self.reason_label.value = f"""