Skip to content

Commit 8fe653c

Browse files
committed
Simplify code in printing.py and remove os.path
1 parent c0d5117 commit 8fe653c

File tree

1 file changed

+10
-14
lines changed

1 file changed

+10
-14
lines changed

pytensor/printing.py

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,14 @@
22

33
import hashlib
44
import logging
5-
import os
65
import sys
76
from abc import ABC, abstractmethod
87
from collections.abc import Callable, Sequence
98
from contextlib import contextmanager
109
from copy import copy
1110
from functools import reduce, singledispatch
1211
from io import StringIO
12+
from pathlib import Path
1313
from typing import Any, Literal, TextIO
1414

1515
import numpy as np
@@ -1200,7 +1200,7 @@ def __call__(self, *args):
12001200

12011201
def pydotprint(
12021202
fct,
1203-
outfile: str | None = None,
1203+
outfile: Path | str | None = None,
12041204
compact: bool = True,
12051205
format: str = "png",
12061206
with_ids: bool = False,
@@ -1295,9 +1295,9 @@ def pydotprint(
12951295
colorCodes = default_colorCodes
12961296

12971297
if outfile is None:
1298-
outfile = os.path.join(
1299-
config.compiledir, "pytensor.pydotprint." + config.device + "." + format
1300-
)
1298+
outfile = config.compiledir / f"pytensor.pydotprint.{config.device}.{format}"
1299+
elif isinstance(outfile, str):
1300+
outfile = Path(outfile)
13011301

13021302
if isinstance(fct, Function):
13031303
profile = getattr(fct, "profile", None)
@@ -1606,23 +1606,19 @@ def apply_name(node):
16061606
g.add_subgraph(c2)
16071607
g.add_subgraph(c3)
16081608

1609-
if not outfile.endswith("." + format):
1610-
outfile += "." + format
1609+
if outfile.suffix != f".{format}":
1610+
outfile = outfile.with_suffix(f".{format}")
16111611

16121612
if scan_graphs:
16131613
scan_ops = [(idx, x) for idx, x in enumerate(topo) if isinstance(x.op, Scan)]
1614-
path, fn = os.path.split(outfile)
1615-
basename = ".".join(fn.split(".")[:-1])
1616-
# Safe way of doing things .. a file name may contain multiple .
1617-
ext = fn[len(basename) :]
16181614

16191615
for idx, scan_op in scan_ops:
16201616
# is there a chance that name is not defined?
16211617
if hasattr(scan_op.op, "name"):
1622-
new_name = basename + "_" + scan_op.op.name + "_" + str(idx)
1618+
new_name = outfile.stem + "_" + scan_op.op.name + "_" + str(idx)
16231619
else:
1624-
new_name = basename + "_" + str(idx)
1625-
new_name = os.path.join(path, new_name + ext)
1620+
new_name = outfile.stem + "_" + str(idx)
1621+
new_name = outfile.with_stem(new_name)
16261622
if hasattr(scan_op.op, "_fn"):
16271623
to_print = scan_op.op.fn
16281624
else:

0 commit comments

Comments
 (0)