|
26 | 26 |
|
27 | 27 | IDTypesType = Literal["id", "int", "CHAR", "auto", ""]
|
28 | 28 |
|
29 |
| -pydot_imported = False |
30 |
| -pydot_imported_msg = "" |
31 |
| -try: |
32 |
| - # pydot-ng is a fork of pydot that is better maintained |
33 |
| - import pydot_ng as pd |
34 |
| - |
35 |
| - if pd.find_graphviz(): |
36 |
| - pydot_imported = True |
37 |
| - else: |
38 |
| - pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." |
39 |
| -except ImportError: |
40 |
| - try: |
41 |
| - # fall back on pydot if necessary |
42 |
| - import pydot as pd |
43 |
| - |
44 |
| - if hasattr(pd, "find_graphviz"): |
45 |
| - if pd.find_graphviz(): |
46 |
| - pydot_imported = True |
47 |
| - else: |
48 |
| - pydot_imported_msg = "pydot can't find graphviz" |
49 |
| - else: |
50 |
| - pd.Dot.create(pd.Dot()) |
51 |
| - pydot_imported = True |
52 |
| - except ImportError: |
53 |
| - # tests should not fail on optional dependency |
54 |
| - pydot_imported_msg = ( |
55 |
| - "Install the python package pydot or pydot-ng. Install graphviz." |
56 |
| - ) |
57 |
| - except Exception as e: |
58 |
| - pydot_imported_msg = "An error happened while importing/trying pydot: " |
59 |
| - pydot_imported_msg += str(e.args) |
60 |
| - |
61 |
| - |
62 | 29 | _logger = logging.getLogger("pytensor.printing")
|
63 | 30 | VALID_ASSOC = {"left", "right", "either"}
|
64 | 31 |
|
@@ -1196,6 +1163,48 @@ def __call__(self, *args):
|
1196 | 1163 | }
|
1197 | 1164 |
|
1198 | 1165 |
|
| 1166 | +def _try_pydot_import(): |
| 1167 | + pydot_imported = False |
| 1168 | + pydot_imported_msg = "" |
| 1169 | + try: |
| 1170 | + # pydot-ng is a fork of pydot that is better maintained |
| 1171 | + import pydot_ng as pd |
| 1172 | + |
| 1173 | + if pd.find_graphviz(): |
| 1174 | + pydot_imported = True |
| 1175 | + else: |
| 1176 | + pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." |
| 1177 | + except ImportError: |
| 1178 | + try: |
| 1179 | + # fall back on pydot if necessary |
| 1180 | + import pydot as pd |
| 1181 | + |
| 1182 | + if hasattr(pd, "find_graphviz"): |
| 1183 | + if pd.find_graphviz(): |
| 1184 | + pydot_imported = True |
| 1185 | + else: |
| 1186 | + pydot_imported_msg = "pydot can't find graphviz" |
| 1187 | + else: |
| 1188 | + pd.Dot.create(pd.Dot()) |
| 1189 | + pydot_imported = True |
| 1190 | + except ImportError: |
| 1191 | + # tests should not fail on optional dependency |
| 1192 | + pydot_imported_msg = ( |
| 1193 | + "Install the python package pydot or pydot-ng. Install graphviz." |
| 1194 | + ) |
| 1195 | + except Exception as e: |
| 1196 | + pydot_imported_msg = "An error happened while importing/trying pydot: " |
| 1197 | + pydot_imported_msg += str(e.args) |
| 1198 | + |
| 1199 | + if not pydot_imported: |
| 1200 | + raise ImportError( |
| 1201 | + "Failed to import pydot. You must install graphviz " |
| 1202 | + "and either pydot or pydot-ng for " |
| 1203 | + f"`pydotprint` to work:\n {pydot_imported_msg}", |
| 1204 | + ) |
| 1205 | + return pd |
| 1206 | + |
| 1207 | + |
1199 | 1208 | def pydotprint(
|
1200 | 1209 | fct,
|
1201 | 1210 | outfile: Path | str | None = None,
|
@@ -1288,6 +1297,8 @@ def pydotprint(
|
1288 | 1297 | scan separately after the top level debugprint output.
|
1289 | 1298 |
|
1290 | 1299 | """
|
| 1300 | + pd = _try_pydot_import() |
| 1301 | + |
1291 | 1302 | from pytensor.scan.op import Scan
|
1292 | 1303 |
|
1293 | 1304 | if colorCodes is None:
|
@@ -1320,12 +1331,6 @@ def pydotprint(
|
1320 | 1331 | outputs = fct.outputs
|
1321 | 1332 | topo = fct.toposort()
|
1322 | 1333 | fgraph = fct
|
1323 |
| - if not pydot_imported: |
1324 |
| - raise RuntimeError( |
1325 |
| - "Failed to import pydot. You must install graphviz " |
1326 |
| - "and either pydot or pydot-ng for " |
1327 |
| - f"`pydotprint` to work:\n {pydot_imported_msg}", |
1328 |
| - ) |
1329 | 1334 |
|
1330 | 1335 | g = pd.Dot()
|
1331 | 1336 |
|
|
0 commit comments