|
26 | 26 |
|
27 | 27 | IDTypesType = Literal["id", "int", "CHAR", "auto", ""] |
28 | 28 |
|
29 | | -pydot_imported = False |
30 | | -pydot_imported_msg = "" |
31 | | -try: |
32 | | - # pydot-ng is a fork of pydot that is better maintained |
33 | | - import pydot_ng as pd |
34 | | - |
35 | | - if pd.find_graphviz(): |
36 | | - pydot_imported = True |
37 | | - else: |
38 | | - pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." |
39 | | -except ImportError: |
40 | | - try: |
41 | | - # fall back on pydot if necessary |
42 | | - import pydot as pd |
43 | | - |
44 | | - if hasattr(pd, "find_graphviz"): |
45 | | - if pd.find_graphviz(): |
46 | | - pydot_imported = True |
47 | | - else: |
48 | | - pydot_imported_msg = "pydot can't find graphviz" |
49 | | - else: |
50 | | - pd.Dot.create(pd.Dot()) |
51 | | - pydot_imported = True |
52 | | - except ImportError: |
53 | | - # tests should not fail on optional dependency |
54 | | - pydot_imported_msg = ( |
55 | | - "Install the python package pydot or pydot-ng. Install graphviz." |
56 | | - ) |
57 | | - except Exception as e: |
58 | | - pydot_imported_msg = "An error happened while importing/trying pydot: " |
59 | | - pydot_imported_msg += str(e.args) |
60 | | - |
61 | | - |
62 | 29 | _logger = logging.getLogger("pytensor.printing") |
63 | 30 | VALID_ASSOC = {"left", "right", "either"} |
64 | 31 |
|
@@ -1196,6 +1163,48 @@ def __call__(self, *args): |
1196 | 1163 | } |
1197 | 1164 |
|
1198 | 1165 |
|
| 1166 | +def _try_pydot_import(): |
| 1167 | + pydot_imported = False |
| 1168 | + pydot_imported_msg = "" |
| 1169 | + try: |
| 1170 | + # pydot-ng is a fork of pydot that is better maintained |
| 1171 | + import pydot_ng as pd |
| 1172 | + |
| 1173 | + if pd.find_graphviz(): |
| 1174 | + pydot_imported = True |
| 1175 | + else: |
| 1176 | + pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." |
| 1177 | + except ImportError: |
| 1178 | + try: |
| 1179 | + # fall back on pydot if necessary |
| 1180 | + import pydot as pd |
| 1181 | + |
| 1182 | + if hasattr(pd, "find_graphviz"): |
| 1183 | + if pd.find_graphviz(): |
| 1184 | + pydot_imported = True |
| 1185 | + else: |
| 1186 | + pydot_imported_msg = "pydot can't find graphviz" |
| 1187 | + else: |
| 1188 | + pd.Dot.create(pd.Dot()) |
| 1189 | + pydot_imported = True |
| 1190 | + except ImportError: |
| 1191 | + # tests should not fail on optional dependency |
| 1192 | + pydot_imported_msg = ( |
| 1193 | + "Install the python package pydot or pydot-ng. Install graphviz." |
| 1194 | + ) |
| 1195 | + except Exception as e: |
| 1196 | + pydot_imported_msg = "An error happened while importing/trying pydot: " |
| 1197 | + pydot_imported_msg += str(e.args) |
| 1198 | + |
| 1199 | + if not pydot_imported: |
| 1200 | + raise ImportError( |
| 1201 | + "Failed to import pydot. You must install graphviz " |
| 1202 | + "and either pydot or pydot-ng for " |
| 1203 | + f"`pydotprint` to work:\n {pydot_imported_msg}", |
| 1204 | + ) |
| 1205 | + return pd |
| 1206 | + |
| 1207 | + |
1199 | 1208 | def pydotprint( |
1200 | 1209 | fct, |
1201 | 1210 | outfile: Path | str | None = None, |
@@ -1288,6 +1297,8 @@ def pydotprint( |
1288 | 1297 | scan separately after the top level debugprint output. |
1289 | 1298 |
|
1290 | 1299 | """ |
| 1300 | + pd = _try_pydot_import() |
| 1301 | + |
1291 | 1302 | from pytensor.scan.op import Scan |
1292 | 1303 |
|
1293 | 1304 | if colorCodes is None: |
@@ -1320,12 +1331,6 @@ def pydotprint( |
1320 | 1331 | outputs = fct.outputs |
1321 | 1332 | topo = fct.toposort() |
1322 | 1333 | fgraph = fct |
1323 | | - if not pydot_imported: |
1324 | | - raise RuntimeError( |
1325 | | - "Failed to import pydot. You must install graphviz " |
1326 | | - "and either pydot or pydot-ng for " |
1327 | | - f"`pydotprint` to work:\n {pydot_imported_msg}", |
1328 | | - ) |
1329 | 1334 |
|
1330 | 1335 | g = pd.Dot() |
1331 | 1336 |
|
|
0 commit comments