4
4
"""
5
5
6
6
import json
7
- import os
8
7
import shutil
8
+ from pathlib import Path
9
9
10
10
from pytensor .d3viz .formatting import PyDotFormatter
11
11
12
12
13
- __path__ = os . path . dirname ( os . path . realpath ( __file__ ))
13
+ __path__ = Path ( __file__ ). parent
14
14
15
15
16
16
def replace_patterns (x , replace ):
@@ -40,7 +40,7 @@ def safe_json(obj):
40
40
return json .dumps (obj ).replace ("<" , "\\ u003c" )
41
41
42
42
43
- def d3viz (fct , outfile , copy_deps = True , * args , ** kwargs ):
43
+ def d3viz (fct , outfile : Path | str , copy_deps : bool = True , * args , ** kwargs ):
44
44
"""Create HTML file with dynamic visualizing of an PyTensor function graph.
45
45
46
46
In the HTML file, the whole graph or single nodes can be moved by drag and
@@ -59,7 +59,7 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
59
59
----------
60
60
fct : pytensor.compile.function.types.Function
61
61
A compiled PyTensor function, variable, apply or a list of variables.
62
- outfile : str
62
+ outfile : Path | str
63
63
Path to output HTML file.
64
64
copy_deps : bool, optional
65
65
Copy javascript and CSS dependencies to output directory.
@@ -78,30 +78,28 @@ def d3viz(fct, outfile, copy_deps=True, *args, **kwargs):
78
78
dot_graph = dot_graph .decode ("utf8" )
79
79
80
80
# Create output directory if not existing
81
- outdir = os .path .dirname (outfile )
82
- if outdir != "" and not os .path .exists (outdir ):
83
- os .makedirs (outdir )
81
+ outdir = Path (outfile ).parent
82
+ outdir .mkdir (exist_ok = True )
84
83
85
84
# Read template HTML file
86
- template_file = os .path .join (__path__ , "html" , "template.html" )
87
- with open (template_file ) as f :
88
- template = f .read ()
85
+ template_file = __path__ / "html/template.html"
86
+ template = template_file .read_text (encoding = "utf-8" )
89
87
90
88
# Copy dependencies to output directory
91
89
src_deps = __path__
92
90
if copy_deps :
93
- dst_deps = "d3viz"
91
+ dst_deps = outdir / "d3viz"
94
92
for d in ("js" , "css" ):
95
- dep = os . path . join ( outdir , dst_deps , d )
96
- if not os . path . exists (dep ):
97
- shutil .copytree (os . path . join ( src_deps , d ) , dep )
93
+ dep = dst_deps / d
94
+ if not dep . exists ():
95
+ shutil .copytree (src_deps / d , dep )
98
96
else :
99
97
dst_deps = src_deps
100
98
101
99
# Replace patterns in template
102
100
replace = {
103
- "%% JS_DIR %%" : os . path . join ( dst_deps , "js" ) ,
104
- "%% CSS_DIR %%" : os . path . join ( dst_deps , "css" ) ,
101
+ "%% JS_DIR %%" : dst_deps / "js" ,
102
+ "%% CSS_DIR %%" : dst_deps / "css" ,
105
103
"%% DOT_GRAPH %%" : safe_json (dot_graph ),
106
104
}
107
105
html = replace_patterns (template , replace )
0 commit comments