Skip to content

Commit 175b67b

Browse files
committed
Lazy pydot import
1 parent 3409264 commit 175b67b

File tree

6 files changed

+76
-56
lines changed

6 files changed

+76
-56
lines changed

pytensor/d3viz/formatting.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,7 @@
1212
from pytensor.compile import Function, builders
1313
from pytensor.graph.basic import Apply, Constant, Variable, graph_inputs
1414
from pytensor.graph.fg import FunctionGraph
15-
from pytensor.printing import pydot_imported, pydot_imported_msg
16-
17-
18-
try:
19-
from pytensor.printing import pd
20-
except ImportError:
21-
pass
15+
from pytensor.printing import _try_pydot_import
2216

2317

2418
class PyDotFormatter:
@@ -41,8 +35,7 @@ class PyDotFormatter:
4135

4236
def __init__(self, compact=True):
4337
"""Construct PyDotFormatter object."""
44-
if not pydot_imported:
45-
raise ImportError("Failed to import pydot. " + pydot_imported_msg)
38+
_try_pydot_import()
4639

4740
self.compact = compact
4841
self.node_colors = {
@@ -115,6 +108,8 @@ def __call__(self, fct, graph=None):
115108
pydot.Dot
116109
Pydot graph of `fct`
117110
"""
111+
pd = _try_pydot_import()
112+
118113
if graph is None:
119114
graph = pd.Dot()
120115

@@ -356,6 +351,8 @@ def type_to_str(t):
356351

357352
def dict_to_pdnode(d):
358353
"""Create pydot node from dict."""
354+
pd = _try_pydot_import()
355+
359356
e = dict()
360357
for k, v in d.items():
361358
if v is not None:

pytensor/printing.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -26,39 +26,6 @@
2626

2727
IDTypesType = Literal["id", "int", "CHAR", "auto", ""]
2828

29-
pydot_imported = False
30-
pydot_imported_msg = ""
31-
try:
32-
# pydot-ng is a fork of pydot that is better maintained
33-
import pydot_ng as pd
34-
35-
if pd.find_graphviz():
36-
pydot_imported = True
37-
else:
38-
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
39-
except ImportError:
40-
try:
41-
# fall back on pydot if necessary
42-
import pydot as pd
43-
44-
if hasattr(pd, "find_graphviz"):
45-
if pd.find_graphviz():
46-
pydot_imported = True
47-
else:
48-
pydot_imported_msg = "pydot can't find graphviz"
49-
else:
50-
pd.Dot.create(pd.Dot())
51-
pydot_imported = True
52-
except ImportError:
53-
# tests should not fail on optional dependency
54-
pydot_imported_msg = (
55-
"Install the python package pydot or pydot-ng. Install graphviz."
56-
)
57-
except Exception as e:
58-
pydot_imported_msg = "An error happened while importing/trying pydot: "
59-
pydot_imported_msg += str(e.args)
60-
61-
6229
_logger = logging.getLogger("pytensor.printing")
6330
VALID_ASSOC = {"left", "right", "either"}
6431

@@ -1196,6 +1163,48 @@ def __call__(self, *args):
11961163
}
11971164

11981165

1166+
def _try_pydot_import():
1167+
pydot_imported = False
1168+
pydot_imported_msg = ""
1169+
try:
1170+
# pydot-ng is a fork of pydot that is better maintained
1171+
import pydot_ng as pd
1172+
1173+
if pd.find_graphviz():
1174+
pydot_imported = True
1175+
else:
1176+
pydot_imported_msg = "pydot-ng can't find graphviz. Install graphviz."
1177+
except ImportError:
1178+
try:
1179+
# fall back on pydot if necessary
1180+
import pydot as pd
1181+
1182+
if hasattr(pd, "find_graphviz"):
1183+
if pd.find_graphviz():
1184+
pydot_imported = True
1185+
else:
1186+
pydot_imported_msg = "pydot can't find graphviz"
1187+
else:
1188+
pd.Dot.create(pd.Dot())
1189+
pydot_imported = True
1190+
except ImportError:
1191+
# tests should not fail on optional dependency
1192+
pydot_imported_msg = (
1193+
"Install the python package pydot or pydot-ng. Install graphviz."
1194+
)
1195+
except Exception as e:
1196+
pydot_imported_msg = "An error happened while importing/trying pydot: "
1197+
pydot_imported_msg += str(e.args)
1198+
1199+
if not pydot_imported:
1200+
raise ImportError(
1201+
"Failed to import pydot. You must install graphviz "
1202+
"and either pydot or pydot-ng for "
1203+
f"`pydotprint` to work:\n {pydot_imported_msg}",
1204+
)
1205+
return pd
1206+
1207+
11991208
def pydotprint(
12001209
fct,
12011210
outfile: Path | str | None = None,
@@ -1288,6 +1297,8 @@ def pydotprint(
12881297
scan separately after the top level debugprint output.
12891298
12901299
"""
1300+
pd = _try_pydot_import()
1301+
12911302
from pytensor.scan.op import Scan
12921303

12931304
if colorCodes is None:
@@ -1320,12 +1331,6 @@ def pydotprint(
13201331
outputs = fct.outputs
13211332
topo = fct.toposort()
13221333
fgraph = fct
1323-
if not pydot_imported:
1324-
raise RuntimeError(
1325-
"Failed to import pydot. You must install graphviz "
1326-
"and either pydot or pydot-ng for "
1327-
f"`pydotprint` to work:\n {pydot_imported_msg}",
1328-
)
13291334

13301335
g = pd.Dot()
13311336

tests/d3viz/test_d3viz.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
from pytensor import compile
1010
from pytensor.compile.function import function
1111
from pytensor.configdefaults import config
12-
from pytensor.printing import pydot_imported, pydot_imported_msg
12+
from pytensor.printing import _try_pydot_import
1313
from tests.d3viz import models
1414

1515

16-
if not pydot_imported:
17-
pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True)
16+
try:
17+
_try_pydot_import()
18+
except Exception as e:
19+
pytest.skip(f"pydot not available: {e!s}", allow_module_level=True)
1820

1921

2022
class TestD3Viz:

tests/d3viz/test_formatting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,11 +3,13 @@
33

44
from pytensor import config, function
55
from pytensor.d3viz.formatting import PyDotFormatter
6-
from pytensor.printing import pydot_imported, pydot_imported_msg
6+
from pytensor.printing import _try_pydot_import
77

88

9-
if not pydot_imported:
10-
pytest.skip("pydot not available: " + pydot_imported_msg, allow_module_level=True)
9+
try:
10+
_try_pydot_import()
11+
except Exception as e:
12+
pytest.skip(f"pydot not available: {e!s}", allow_module_level=True)
1113

1214
from tests.d3viz import models
1315

tests/scan/test_printing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import pytensor.tensor as pt
66
from pytensor.configdefaults import config
77
from pytensor.graph.fg import FunctionGraph
8-
from pytensor.printing import debugprint, pydot_imported, pydotprint
8+
from pytensor.printing import _try_pydot_import, debugprint, pydotprint
99
from pytensor.tensor.type import dvector, iscalar, scalar, vector
1010

1111

@@ -686,6 +686,13 @@ def no_shared_fn(n, x_tm1, M):
686686
assert truth.strip() == out.strip()
687687

688688

689+
try:
690+
_try_pydot_import()
691+
pydot_imported = True
692+
except Exception:
693+
pydot_imported = False
694+
695+
689696
@pytest.mark.skipif(not pydot_imported, reason="pydot not available")
690697
def test_pydotprint():
691698
def f_pow2(x_tm1):

tests/test_printing.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,20 +17,27 @@
1717
PatternPrinter,
1818
PPrinter,
1919
Print,
20+
_try_pydot_import,
2021
char_from_number,
2122
debugprint,
2223
default_printer,
2324
get_node_by_id,
2425
min_informative_str,
2526
pp,
26-
pydot_imported,
2727
pydotprint,
2828
)
2929
from pytensor.tensor import as_tensor_variable
3030
from pytensor.tensor.type import dmatrix, dvector, matrix
3131
from tests.graph.utils import MyInnerGraphOp, MyOp, MyVariable
3232

3333

34+
try:
35+
_try_pydot_import()
36+
pydot_imported = True
37+
except Exception:
38+
pydot_imported = False
39+
40+
3441
@pytest.mark.parametrize(
3542
"number,s",
3643
[

0 commit comments

Comments
 (0)