|
26 | 26 |
|
27 | 27 | IDTypesType = Literal["id", "int", "CHAR", "auto", ""] |
28 | 28 |
|
29 | | -pydot_imported = False |
30 | | -pydot_imported_msg = "" |
31 | | -try: |
32 | | - # pydot-ng is a fork of pydot that is better maintained |
33 | | - import pydot_ng as pd |
34 | | - |
35 | | - if pd.find_graphviz(): |
36 | | - pydot_imported = True |
37 | | - else: |
38 | | - pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." |
39 | | -except ImportError: |
40 | | - try: |
41 | | - # fall back on pydot if necessary |
42 | | - import pydot as pd |
43 | | - |
44 | | - if hasattr(pd, "find_graphviz"): |
45 | | - if pd.find_graphviz(): |
46 | | - pydot_imported = True |
47 | | - else: |
48 | | - pydot_imported_msg = "pydot can't find graphviz" |
49 | | - else: |
50 | | - pd.Dot.create(pd.Dot()) |
51 | | - pydot_imported = True |
52 | | - except ImportError: |
53 | | - # tests should not fail on optional dependency |
54 | | - pydot_imported_msg = ( |
55 | | - "Install the python package pydot or pydot-ng. Install graphviz." |
56 | | - ) |
57 | | - except Exception as e: |
58 | | - pydot_imported_msg = "An error happened while importing/trying pydot: " |
59 | | - pydot_imported_msg += str(e.args) |
60 | | - |
61 | | - |
62 | 29 | _logger = logging.getLogger("pytensor.printing") |
63 | 30 | VALID_ASSOC = {"left", "right", "either"} |
64 | 31 |
|
@@ -1288,6 +1255,45 @@ def pydotprint( |
1288 | 1255 | scan separately after the top level debugprint output. |
1289 | 1256 |
|
1290 | 1257 | """ |
| 1258 | + pydot_imported = False |
| 1259 | + pydot_imported_msg = "" |
| 1260 | + try: |
| 1261 | + # pydot-ng is a fork of pydot that is better maintained |
| 1262 | + import pydot_ng as pd |
| 1263 | + |
| 1264 | + if pd.find_graphviz(): |
| 1265 | + pydot_imported = True |
| 1266 | + else: |
| 1267 | + pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz." |
| 1268 | + except ImportError: |
| 1269 | + try: |
| 1270 | + # fall back on pydot if necessary |
| 1271 | + import pydot as pd |
| 1272 | + |
| 1273 | + if hasattr(pd, "find_graphviz"): |
| 1274 | + if pd.find_graphviz(): |
| 1275 | + pydot_imported = True |
| 1276 | + else: |
| 1277 | + pydot_imported_msg = "pydot can't find graphviz" |
| 1278 | + else: |
| 1279 | + pd.Dot.create(pd.Dot()) |
| 1280 | + pydot_imported = True |
| 1281 | + except ImportError: |
| 1282 | + # tests should not fail on optional dependency |
| 1283 | + pydot_imported_msg = ( |
| 1284 | + "Install the python package pydot or pydot-ng. Install graphviz." |
| 1285 | + ) |
| 1286 | + except Exception as e: |
| 1287 | + pydot_imported_msg = "An error happened while importing/trying pydot: " |
| 1288 | + pydot_imported_msg += str(e.args) |
| 1289 | + |
| 1290 | + if not pydot_imported: |
| 1291 | + raise RuntimeError( |
| 1292 | + "Failed to import pydot. You must install graphviz " |
| 1293 | + "and either pydot or pydot-ng for " |
| 1294 | + f"`pydotprint` to work:\n {pydot_imported_msg}", |
| 1295 | + ) |
| 1296 | + |
1291 | 1297 | from pytensor.scan.op import Scan |
1292 | 1298 |
|
1293 | 1299 | if colorCodes is None: |
@@ -1320,12 +1326,6 @@ def pydotprint( |
1320 | 1326 | outputs = fct.outputs |
1321 | 1327 | topo = fct.toposort() |
1322 | 1328 | fgraph = fct |
1323 | | - if not pydot_imported: |
1324 | | - raise RuntimeError( |
1325 | | - "Failed to import pydot. You must install graphviz " |
1326 | | - "and either pydot or pydot-ng for " |
1327 | | - f"`pydotprint` to work:\n {pydot_imported_msg}", |
1328 | | - ) |
1329 | 1329 |
|
1330 | 1330 | g = pd.Dot() |
1331 | 1331 |
|
|
0 commit comments