Skip to content

Commit 93f4ce0

Browse files
committed
Remove os.path in d3viz.py
1 parent 8fd9f82 commit 93f4ce0

File tree

1 file changed

+14
-16
lines changed

1 file changed

+14
-16
lines changed

pytensor/d3viz/d3viz.py

Lines changed: 14 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@
44
"""
55

66
import json
7-
import os
87
import shutil
8+
from pathlib import Path
99

1010
from pytensor.d3viz.formatting import PyDotFormatter
1111

1212

13-
__path__ = os.path.dirname(os.path.realpath(__file__))
13+
__path__ = Path(__file__).parent
1414

1515

1616
def replace_patterns(x, replace):
@@ -40,7 +40,7 @@ def safe_json(obj):
4040
return json.dumps(obj).replace("<", "\\u003c")
4141

4242

43-
def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
43+
def d3viz(fct, outfile: Path | str, copy_deps: bool = True, *args, **kwargs):
4444
"""Create HTML file with dynamic visualizing of an PyTensor function graph.
4545
4646
In the HTML file, the whole graph or single nodes can be moved by drag and
@@ -59,7 +59,7 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
5959
----------
6060
fct : pytensor.compile.function.types.Function
6161
A compiled PyTensor function, variable, apply or a list of variables.
62-
outfile : str
62+
outfile : Path | str
6363
Path to output HTML file.
6464
copy_deps : bool, optional
6565
Copy javascript and CSS dependencies to output directory.
@@ -78,30 +78,28 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
7878
dot_graph = dot_graph.decode("utf8")
7979

8080
# Create output directory if not existing
81-
outdir = os.path.dirname(outfile)
82-
if outdir != "" and not os.path.exists(outdir):
83-
os.makedirs(outdir)
81+
outdir = Path(outfile).parent
82+
outdir.mkdir(exist_ok=True)
8483

8584
# Read template HTML file
86-
template_file = os.path.join(__path__, "html", "template.html")
87-
with open(template_file) as f:
88-
template = f.read()
85+
template_file = __path__ / "html/template.html"
86+
template = template_file.read_text(encoding="utf-8")
8987

9088
# Copy dependencies to output directory
9189
src_deps = __path__
9290
if copy_deps:
93-
dst_deps = "d3viz"
91+
dst_deps = outdir / "d3viz"
9492
for d in ("js", "css"):
95-
dep = os.path.join(outdir, dst_deps, d)
96-
if not os.path.exists(dep):
97-
shutil.copytree(os.path.join(src_deps, d), dep)
93+
dep = dst_deps / d
94+
if not dep.exists():
95+
shutil.copytree(src_deps / d, dep)
9896
else:
9997
dst_deps = src_deps
10098

10199
# Replace patterns in template
102100
replace = {
103-
"%% JS_DIR %%": os.path.join(dst_deps, "js"),
104-
"%% CSS_DIR %%": os.path.join(dst_deps, "css"),
101+
"%% JS_DIR %%": dst_deps / "js",
102+
"%% CSS_DIR %%": dst_deps / "css",
105103
"%% DOT_GRAPH %%": safe_json(dot_graph),
106104
}
107105
html = replace_patterns(template, replace)

0 commit comments

Comments
 (0)