Skip to content

Commit 1183a38

Browse files
committed
Lazy pydot import
1 parent f1f9905 commit 1183a38

File tree

1 file changed

+39
-39
lines changed

1 file changed

+39
-39
lines changed

pytensor/printing.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,6 @@
2626

2727
IDTypesType = Literal["id", "int", "CHAR", "auto", ""]
2828

29-
pydot_imported = False
30-
pydot_imported_msg = ""
31-
try:
32-
# pydot-ng is a fork of pydot that is better maintained
33-
import pydot_ng as pd
34-
35-
if pd.find_graphviz():
36-
pydot_imported = True
37-
else:
38-
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
39-
except ImportError:
40-
try:
41-
# fall back on pydot if necessary
42-
import pydot as pd
43-
44-
if hasattr(pd, "find_graphviz"):
45-
if pd.find_graphviz():
46-
pydot_imported = True
47-
else:
48-
pydot_imported_msg = "pydot can't find graphviz"
49-
else:
50-
pd.Dot.create(pd.Dot())
51-
pydot_imported = True
52-
except ImportError:
53-
# tests should not fail on optional dependency
54-
pydot_imported_msg = (
55-
"Install the python package pydot or pydot-ng. Install graphviz."
56-
)
57-
except Exception as e:
58-
pydot_imported_msg = "An error happened while importing/trying pydot: "
59-
pydot_imported_msg += str(e.args)
60-
61-
6229
_logger = logging.getLogger("pytensor.printing")
6330
VALID_ASSOC = {"left", "right", "either"}
6431

@@ -1288,6 +1255,45 @@ def pydotprint(
12881255
scan separately after the top level debugprint output.
12891256
12901257
"""
1258+
pydot_imported = False
1259+
pydot_imported_msg = ""
1260+
try:
1261+
# pydot-ng is a fork of pydot that is better maintained
1262+
import pydot_ng as pd
1263+
1264+
if pd.find_graphviz():
1265+
pydot_imported = True
1266+
else:
1267+
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
1268+
except ImportError:
1269+
try:
1270+
# fall back on pydot if necessary
1271+
import pydot as pd
1272+
1273+
if hasattr(pd, "find_graphviz"):
1274+
if pd.find_graphviz():
1275+
pydot_imported = True
1276+
else:
1277+
pydot_imported_msg = "pydot can't find graphviz"
1278+
else:
1279+
pd.Dot.create(pd.Dot())
1280+
pydot_imported = True
1281+
except ImportError:
1282+
# tests should not fail on optional dependency
1283+
pydot_imported_msg = (
1284+
"Install the python package pydot or pydot-ng. Install graphviz."
1285+
)
1286+
except Exception as e:
1287+
pydot_imported_msg = "An error happened while importing/trying pydot: "
1288+
pydot_imported_msg += str(e.args)
1289+
1290+
if not pydot_imported:
1291+
raise RuntimeError(
1292+
"Failed to import pydot. You must install graphviz "
1293+
"and either pydot or pydot-ng for "
1294+
f"`pydotprint` to work:\n {pydot_imported_msg}",
1295+
)
1296+
12911297
from pytensor.scan.op import Scan
12921298

12931299
if colorCodes is None:
@@ -1320,12 +1326,6 @@ def pydotprint(
13201326
outputs = fct.outputs
13211327
topo = fct.toposort()
13221328
fgraph = fct
1323-
if not pydot_imported:
1324-
raise RuntimeError(
1325-
"Failed to import pydot. You must install graphviz "
1326-
"and either pydot or pydot-ng for "
1327-
f"`pydotprint` to work:\n {pydot_imported_msg}",
1328-
)
13291329

13301330
g = pd.Dot()
13311331

0 commit comments

Comments
 (0)