|
38 | 38 | from .thunk import * |
39 | 39 |
|
40 | 40 | if TYPE_CHECKING: |
41 | | - import ipywidgets |
42 | | - |
43 | 41 | from .builtins import Bool, PyObject, String, f64, i64 |
44 | 42 |
|
45 | 43 |
|
@@ -973,6 +971,7 @@ class GraphvizKwargs(TypedDict, total=False): |
973 | 971 | n_inline_leaves: int |
974 | 972 | split_primitive_outputs: bool |
975 | 973 | split_functions: list[object] |
| 974 | + include_temporary_functions: bool |
976 | 975 |
|
977 | 976 |
|
978 | 977 | @dataclass |
@@ -1015,82 +1014,8 @@ def as_egglog_string(self) -> str: |
1015 | 1014 | raise ValueError(msg) |
1016 | 1015 | return cmds |
1017 | 1016 |
|
1018 | | - def _repr_mimebundle_(self, *args, **kwargs): |
1019 | | - """ |
1020 | | - Returns the graphviz representation of the e-graph. |
1021 | | - """ |
1022 | | - return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")} |
1023 | | - |
1024 | | - def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: |
1025 | | - # By default we want to split primitive outputs |
1026 | | - split_primitive_outputs = kwargs.pop("split_primitive_outputs", True) |
1027 | | - split_additional_functions = kwargs.pop("split_functions", []) |
1028 | | - n_inline = kwargs.pop("n_inline_leaves", 0) |
1029 | | - serialized = self._egraph.serialize( |
1030 | | - [], |
1031 | | - max_functions=kwargs.pop("max_functions", None), |
1032 | | - max_calls_per_function=kwargs.pop("max_calls_per_function", None), |
1033 | | - include_temporary_functions=False, |
1034 | | - ) |
1035 | | - if split_primitive_outputs or split_additional_functions: |
1036 | | - additional_ops = set(map(self._callable_to_egg, split_additional_functions)) |
1037 | | - serialized.split_e_classes(self._egraph, additional_ops) |
1038 | | - serialized.map_ops(self._state.op_mapping()) |
1039 | | - |
1040 | | - for _ in range(n_inline): |
1041 | | - serialized.inline_leaves() |
1042 | | - original = serialized.to_dot() |
1043 | | - # Add link to stylesheet to the graph, so that edges light up on hover |
1044 | | - # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682 |
1045 | | - styles = """/* the lines within the edges */ |
1046 | | - .edge:active path, |
1047 | | - .edge:hover path { |
1048 | | - stroke: fuchsia; |
1049 | | - stroke-width: 3; |
1050 | | - stroke-opacity: 1; |
1051 | | - } |
1052 | | - /* arrows are typically drawn with a polygon */ |
1053 | | - .edge:active polygon, |
1054 | | - .edge:hover polygon { |
1055 | | - stroke: fuchsia; |
1056 | | - stroke-width: 3; |
1057 | | - fill: fuchsia; |
1058 | | - stroke-opacity: 1; |
1059 | | - fill-opacity: 1; |
1060 | | - } |
1061 | | - /* If you happen to have text and want to color that as well... */ |
1062 | | - .edge:active text, |
1063 | | - .edge:hover text { |
1064 | | - fill: fuchsia; |
1065 | | - }""" |
1066 | | - p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css" |
1067 | | - p.write_text(styles) |
1068 | | - with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1) |
1069 | | - return graphviz.Source(with_stylesheet) |
1070 | | - |
1071 | | - def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str: |
1072 | | - return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8") |
1073 | | - |
1074 | | - def _repr_html_(self) -> str: |
1075 | | - """ |
1076 | | - Add a _repr_html_ to be an SVG to work with sphinx gallery. |
1077 | | -
|
1078 | | - ala https://github.com/xflr6/graphviz/pull/121 |
1079 | | - until this PR is merged and released |
1080 | | - https://github.com/sphinx-gallery/sphinx-gallery/pull/1138 |
1081 | | - """ |
1082 | | - return self.graphviz_svg() |
1083 | | - |
1084 | | - def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None: |
1085 | | - """ |
1086 | | - Displays the e-graph in the notebook. |
1087 | | - """ |
1088 | | - if IN_IPYTHON: |
1089 | | - from IPython.display import SVG, display |
1090 | | - |
1091 | | - display(SVG(self.graphviz_svg(**kwargs))) |
1092 | | - else: |
1093 | | - self.graphviz(**kwargs).render(view=True, format="svg", quiet=True) |
| 1017 | + def _ipython_display_(self) -> None: |
| 1018 | + self.display() |
1094 | 1019 |
|
1095 | 1020 | def input(self, fn: Callable[..., String], path: str) -> None: |
1096 | 1021 | """ |
@@ -1319,40 +1244,100 @@ def eval(self, expr: Expr) -> object: |
1319 | 1244 | return self._egraph.eval_py_object(egg_expr) |
1320 | 1245 | raise TypeError(f"Eval not implemented for {typed_expr.tp}") |
1321 | 1246 |
|
1322 | | - def saturate( |
| 1247 | + def _serialize( |
1323 | 1248 | self, |
1324 | | - schedule: Schedule | None = None, |
1325 | | - *, |
1326 | | - max: int = 1000, |
1327 | | - performance: bool = False, |
1328 | 1249 | **kwargs: Unpack[GraphvizKwargs], |
1329 | | - ) -> ipywidgets.Widget: |
1330 | | - from .graphviz_widget import graphviz_widget_with_slider |
| 1250 | + ) -> bindings.SerializedEGraph: |
| 1251 | + max_functions = kwargs.pop("max_functions", None) |
| 1252 | + max_calls_per_function = kwargs.pop("max_calls_per_function", None) |
| 1253 | + split_primitive_outputs = kwargs.pop("split_primitive_outputs", True) |
| 1254 | + split_functions = kwargs.pop("split_functions", []) |
| 1255 | + include_temporary_functions = kwargs.pop("include_temporary_functions", False) |
| 1256 | + n_inline_leaves = kwargs.pop("n_inline_leaves", 1) |
| 1257 | + serialized = self._egraph.serialize( |
| 1258 | + [], |
| 1259 | + max_functions=max_functions, |
| 1260 | + max_calls_per_function=max_calls_per_function, |
| 1261 | + include_temporary_functions=include_temporary_functions, |
| 1262 | + ) |
| 1263 | + if split_primitive_outputs or split_functions: |
| 1264 | + additional_ops = set(map(self._callable_to_egg, split_functions)) |
| 1265 | + serialized.split_e_classes(self._egraph, additional_ops) |
| 1266 | + serialized.map_ops(self._state.op_mapping()) |
1331 | 1267 |
|
1332 | | - dots = [str(self.graphviz(**kwargs))] |
1333 | | - i = 0 |
1334 | | - while self.run(schedule or 1).updated and i < max: |
1335 | | - i += 1 |
1336 | | - dots.append(str(self.graphviz(**kwargs))) |
1337 | | - return graphviz_widget_with_slider(dots, performance=performance) |
| 1268 | + for _ in range(n_inline_leaves): |
| 1269 | + serialized.inline_leaves() |
1338 | 1270 |
|
1339 | | - def saturate_to_html( |
1340 | | - self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs] |
1341 | | - ) -> None: |
1342 | | - # raise NotImplementedError("Upstream bugs prevent rendering to HTML") |
| 1271 | + return serialized |
1343 | 1272 |
|
1344 | | - # import panel |
| 1273 | + def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: |
| 1274 | + serialized = self._serialize(**kwargs) |
1345 | 1275 |
|
1346 | | - # panel.extension("ipywidgets") |
| 1276 | + original = serialized.to_dot() |
| 1277 | + # Add link to stylesheet to the graph, so that edges light up on hover |
| 1278 | + # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682 |
| 1279 | + styles = """/* the lines within the edges */ |
| 1280 | + .edge:active path, |
| 1281 | + .edge:hover path { |
| 1282 | + stroke: fuchsia; |
| 1283 | + stroke-width: 3; |
| 1284 | + stroke-opacity: 1; |
| 1285 | + } |
| 1286 | + /* arrows are typically drawn with a polygon */ |
| 1287 | + .edge:active polygon, |
| 1288 | + .edge:hover polygon { |
| 1289 | + stroke: fuchsia; |
| 1290 | + stroke-width: 3; |
| 1291 | + fill: fuchsia; |
| 1292 | + stroke-opacity: 1; |
| 1293 | + fill-opacity: 1; |
| 1294 | + } |
| 1295 | + /* If you happen to have text and want to color that as well... */ |
| 1296 | + .edge:active text, |
| 1297 | + .edge:hover text { |
| 1298 | + fill: fuchsia; |
| 1299 | + }""" |
| 1300 | + p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css" |
| 1301 | + p.write_text(styles) |
| 1302 | + with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1) |
| 1303 | + return graphviz.Source(with_stylesheet) |
1347 | 1304 |
|
1348 | | - widget = self.saturate(performance=performance, **kwargs) |
1349 | | - # panel.panel(widget).save(file) |
| 1305 | + def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None: |
| 1306 | + """ |
| 1307 | + Displays the e-graph. |
| 1308 | +
|
| 1309 | + If in IPython it will display it inline, otherwise it will write it to a file and open it. |
| 1310 | + """ |
| 1311 | + from IPython.display import SVG, display |
| 1312 | + |
| 1313 | + from .visualizer_widget import VisualizerWidget |
| 1314 | + |
| 1315 | + if graphviz: |
| 1316 | + if IN_IPYTHON: |
| 1317 | + svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8") |
| 1318 | + display(SVG(svg)) |
| 1319 | + else: |
| 1320 | + self._graphviz(**kwargs).render(view=True, format="svg", quiet=True) |
| 1321 | + else: |
| 1322 | + serialized = self._serialize(**kwargs) |
| 1323 | + VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open() |
1350 | 1324 |
|
1351 | | - from ipywidgets.embed import embed_minimal_html |
| 1325 | + def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None: |
| 1326 | + """ |
| 1327 | + Saturate the egraph, running the given schedule until the egraph is saturated. |
| 1328 | + It serializes the egraph at each step and returns a widget to visualize the egraph. |
| 1329 | + """ |
| 1330 | + from .visualizer_widget import VisualizerWidget |
| 1331 | + |
| 1332 | + def to_json() -> str: |
| 1333 | + return self._serialize(**kwargs).to_json() |
1352 | 1334 |
|
1353 | | - embed_minimal_html("tmp.html", views=[widget], drop_defaults=False) |
1354 | | - # Use panel while this issue persists |
1355 | | - # https://github.com/jupyter-widgets/ipywidgets/issues/3761#issuecomment-1755563436 |
| 1335 | + egraphs = [to_json()] |
| 1336 | + i = 0 |
| 1337 | + while self.run(schedule or 1).updated and i < max: |
| 1338 | + i += 1 |
| 1339 | + egraphs.append(to_json()) |
| 1340 | + VisualizerWidget(egraphs=egraphs).display_or_open() |
1356 | 1341 |
|
1357 | 1342 | @classmethod |
1358 | 1343 | def current(cls) -> EGraph: |
|
0 commit comments