|
2 | 2 |
|
3 | 3 | import hashlib
|
4 | 4 | import logging
|
5 |
| -import os |
6 | 5 | import sys
|
7 | 6 | from abc import ABC, abstractmethod
|
8 | 7 | from collections.abc import Callable, Sequence
|
9 | 8 | from contextlib import contextmanager
|
10 | 9 | from copy import copy
|
11 | 10 | from functools import reduce, singledispatch
|
12 | 11 | from io import StringIO
|
| 12 | +from pathlib import Path |
13 | 13 | from typing import Any, Literal, TextIO
|
14 | 14 |
|
15 | 15 | import numpy as np
|
@@ -1200,7 +1200,7 @@ def __call__(self, *args):
|
1200 | 1200 |
|
1201 | 1201 | def pydotprint(
|
1202 | 1202 | fct,
|
1203 |
| - outfile: str | None = None, |
| 1203 | + outfile: Path | str | None = None, |
1204 | 1204 | compact: bool = True,
|
1205 | 1205 | format: str = "png",
|
1206 | 1206 | with_ids: bool = False,
|
@@ -1295,9 +1295,9 @@ def pydotprint(
|
1295 | 1295 | colorCodes = default_colorCodes
|
1296 | 1296 |
|
1297 | 1297 | if outfile is None:
|
1298 |
| - outfile = os.path.join( |
1299 |
| - config.compiledir, "pytensor.pydotprint." + config.device + "." + format |
1300 |
| - ) |
| 1298 | + outfile = config.compiledir / f"pytensor.pydotprint.{config.device}.{format}" |
| 1299 | + elif isinstance(outfile, str): |
| 1300 | + outfile = Path(outfile) |
1301 | 1301 |
|
1302 | 1302 | if isinstance(fct, Function):
|
1303 | 1303 | profile = getattr(fct, "profile", None)
|
@@ -1606,23 +1606,19 @@ def apply_name(node):
|
1606 | 1606 | g.add_subgraph(c2)
|
1607 | 1607 | g.add_subgraph(c3)
|
1608 | 1608 |
|
1609 |
| - if not outfile.endswith("." + format): |
1610 |
| - outfile += "." + format |
| 1609 | + if outfile.suffix != f".{format}": |
| 1610 | + outfile = outfile.with_suffix(f".{format}") |
1611 | 1611 |
|
1612 | 1612 | if scan_graphs:
|
1613 | 1613 | scan_ops = [(idx, x) for idx, x in enumerate(topo) if isinstance(x.op, Scan)]
|
1614 |
| - path, fn = os.path.split(outfile) |
1615 |
| - basename = ".".join(fn.split(".")[:-1]) |
1616 |
| - # Safe way of doing things .. a file name may contain multiple . |
1617 |
| - ext = fn[len(basename) :] |
1618 | 1614 |
|
1619 | 1615 | for idx, scan_op in scan_ops:
|
1620 | 1616 | # is there a chance that name is not defined?
|
1621 | 1617 | if hasattr(scan_op.op, "name"):
|
1622 |
| - new_name = basename + "_" + scan_op.op.name + "_" + str(idx) |
| 1618 | + new_name = outfile.stem + "_" + scan_op.op.name + "_" + str(idx) |
1623 | 1619 | else:
|
1624 |
| - new_name = basename + "_" + str(idx) |
1625 |
| - new_name = os.path.join(path, new_name + ext) |
| 1620 | + new_name = outfile.stem + "_" + str(idx) |
| 1621 | + new_name = outfile.with_stem(new_name) |
1626 | 1622 | if hasattr(scan_op.op, "_fn"):
|
1627 | 1623 | to_print = scan_op.op.fn
|
1628 | 1624 | else:
|
|
0 commit comments