Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ exclude: |
)$
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v5.0.0
rev: v6.0.0
hooks:
- id: debug-statements
exclude: |
Expand All @@ -25,11 +25,11 @@ repos:
rev: v1.0.0
hooks:
- id: sphinx-lint
args: ["."]
args: ["-i", ".pixi", "."]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.3
rev: v0.14.0
hooks:
- id: ruff
- id: ruff-check
types_or: [python, pyi, jupyter]
args: ["--fix", "--output-format=full"]
- id: ruff-format
Expand Down
9 changes: 3 additions & 6 deletions pytensor/bin/pytensor_cache.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import logging
import os
import sys
from pathlib import Path


if sys.platform == "win32":
Expand All @@ -24,7 +25,7 @@

def print_help(exit_status):
if exit_status:
print(f"command \"{' '.join(sys.argv)}\" not recognized")
print(f'command "{" ".join(sys.argv)}" not recognized')
print('Type "pytensor-cache" to print the cache location')
print('Type "pytensor-cache help" to print this help')
print('Type "pytensor-cache clear" to erase the cache')
Expand Down Expand Up @@ -65,11 +66,7 @@ def main():
# Print a warning if some cached modules were not removed, so that the
# user knows he should manually delete them, or call
# pytensor-cache purge, # to properly clear the cache.
items = [
item
for item in sorted(os.listdir(cache.dirname))
if item.startswith("tmp")
]
items = list(Path(cache.dirname).glob("tmp*"))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this sorted by default?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it doesn't matter though

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is not, but yes, it doesn't matter here.

if items:
_logger.warning(
"There remain elements in the cache dir that you may "
Expand Down
2 changes: 1 addition & 1 deletion pytensor/compile/function/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from pytensor.graph import Variable


__all__ = ["types", "pfunc"]
__all__ = ["pfunc", "types"]

__docformat__ = "restructuredtext en"
_logger = logging.getLogger("pytensor.compile.function")
Expand Down
3 changes: 1 addition & 2 deletions pytensor/compile/function/pfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,8 +328,7 @@ def clone_inputs(i):
cloned_outputs = [] # TODO: get Function.__call__ to return None
else:
raise TypeError(
"output must be an PyTensor Variable or Out "
"instance (or list of them)",
"output must be an PyTensor Variable or Out instance (or list of them)",
outputs,
)

Expand Down
2 changes: 1 addition & 1 deletion pytensor/compile/function/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ def add_supervisor_to_fgraph(
input
for spec, input in zip(input_specs, fgraph.inputs, strict=True)
if not (
spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
spec.mutable or (has_destroy_handler and fgraph.has_destroyers([input]))
)
)
)
Expand Down
3 changes: 1 addition & 2 deletions pytensor/compile/monitormode.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ def __init__(
optimizer = config.optimizer
if linker is not None and not isinstance(linker.mode, MonitorMode):
raise Exception(
"MonitorMode can only use its own linker! You "
"should not provide one.",
"MonitorMode can only use its own linker! You should not provide one.",
linker,
)

Expand Down
61 changes: 25 additions & 36 deletions pytensor/compile/profiling.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,14 +514,12 @@ def summary_class(self, file=sys.stderr, N=None):
# While this carries over less information, it is arranged such
# that it is way more readable than the previous output of the
# profiler
nb_classes = max(0, len(otimes) - N)
percent_total = sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:])
time_total = sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])
print(
" ... (remaining %i Classes account for %6.2f%%(%.2fs) of "
"the runtime)"
% (
max(0, len(otimes) - N),
sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]),
sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:]),
),
f" ... (remaining {nb_classes} Classes account for "
f"{percent_total:6.2f}%%({time_total:.2f}s) of the runtime)",
file=file,
)
print("", file=file)
Expand Down Expand Up @@ -607,14 +605,12 @@ def summary_ops(self, file=sys.stderr, N=None):
# While this carries over less information, it is arranged such
# that it is way more readable than the previous output of the
# profiler
nb_ops = max(0, len(otimes) - N)
percent_total = sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:])
time_total = sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])
print(
" ... (remaining %i Ops account for %6.2f%%(%.2fs) of "
"the runtime)"
% (
max(0, len(otimes) - N),
sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]),
sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:]),
),
f" ... (remaining {nb_ops} Ops account for "
f"{percent_total:6.2f}%%({time_total:.2f}s) of the runtime)",
file=file,
)
print("", file=file)
Expand Down Expand Up @@ -935,18 +931,14 @@ def count_running_memory(order, fgraph, nodes_mem, ignore_dmap=False):
if dmap and idx2 in dmap:
vidx = dmap[idx2]
assert len(vidx) == 1, (
"Here we only support the "
"possibility to destroy one "
"input"
"Here we only support the possibility to destroy one input"
)
ins = node.inputs[vidx[0]]
if vmap and idx2 in vmap:
assert ins is None
vidx = vmap[idx2]
assert len(vidx) == 1, (
"Here we only support the "
"possibility to view one "
"input"
"Here we only support the possibility to view one input"
)
ins = node.inputs[vidx[0]]
if ins is not None:
Expand Down Expand Up @@ -1093,9 +1085,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
assert ins is None
vidx = vmap[idx]
assert len(vidx) == 1, (
"Here we only support "
"the possibility to "
"view one input"
"Here we only support the possibility to view one input"
)
ins = node.inputs[vidx[0]]
if ins is not None:
Expand Down Expand Up @@ -1304,22 +1294,22 @@ def print_stats(stats1, stats2):

print(
(
f" CPU: {int(round(new_max_running_max_memory_size[1] / 1024.0))}KB "
f"({int(round(max_running_max_memory_size[1] / 1024.0))}KB)"
f" CPU: {round(new_max_running_max_memory_size[1] / 1024.0)}KB "
f"({round(max_running_max_memory_size[1] / 1024.0)}KB)"
),
file=file,
)
print(
(
f" GPU: {int(round(new_max_running_max_memory_size[2] / 1024.0))}KB "
f"({int(round(max_running_max_memory_size[2] / 1024.0))}KB)"
f" GPU: {round(new_max_running_max_memory_size[2] / 1024.0)}KB "
f"({round(max_running_max_memory_size[2] / 1024.0)}KB)"
),
file=file,
)
print(
(
f" CPU + GPU: {int(round(new_max_running_max_memory_size[0] / 1024.0))}KB "
f"({int(round(max_running_max_memory_size[0] / 1024.0))}KB)"
f" CPU + GPU: {round(new_max_running_max_memory_size[0] / 1024.0)}KB "
f"({round(max_running_max_memory_size[0] / 1024.0)}KB)"
),
file=file,
)
Expand All @@ -1340,23 +1330,23 @@ def print_stats(stats1, stats2):
file=file,
)
print(
f" CPU: {int(round(new_max_node_memory_size[1] / 1024.0))}KB",
f" CPU: {round(new_max_node_memory_size[1] / 1024.0)}KB",
file=file,
)
print(
f" GPU: {int(round(new_max_node_memory_size[2] / 1024.0))}KB",
f" GPU: {round(new_max_node_memory_size[2] / 1024.0)}KB",
file=file,
)
print(
f" CPU + GPU: {int(round(new_max_node_memory_size[0] / 1024.0))}KB",
f" CPU + GPU: {round(new_max_node_memory_size[0] / 1024.0)}KB",
file=file,
)
print("---", file=file)

if min_max_peak:
print(
" Minimum peak from all valid apply node order is "
f"{int(round(min_max_peak / 1024.0))}KB(took {min_peak_time:3f}s to compute)",
f"{round(min_max_peak / 1024.0)}KB(took {min_peak_time:3f}s to compute)",
file=file,
)

Expand Down Expand Up @@ -1405,7 +1395,7 @@ def print_stats(stats1, stats2):
print(
(
f" ... (remaining {max(0, len(node_mem) - N)} Apply account for "
f"{sum_remaining:4d}B/{size_sum_dense :d}B ({p}) of the"
f"{sum_remaining:4d}B/{size_sum_dense:d}B ({p}) of the"
" Apply with dense outputs sizes)"
),
file=file,
Expand Down Expand Up @@ -1545,8 +1535,7 @@ def amdlibm_speed_up(op):
return True
elif s_op.__class__ not in scalar_op_amdlibm_no_speed_up:
print(
"We don't know if amdlibm will accelerate "
"this scalar op.",
"We don't know if amdlibm will accelerate this scalar op.",
s_op,
file=file,
)
Expand Down
3 changes: 1 addition & 2 deletions pytensor/configdefaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,8 +67,7 @@ def _warn_cxx(val):
"""We only support clang++ as otherwise we hit strange g++/OSX bugs."""
if sys.platform == "darwin" and val and "clang++" not in val:
_logger.warning(
"Only clang++ is supported. With g++,"
" we end up with strange g++/OSX bugs."
"Only clang++ is supported. With g++, we end up with strange g++/OSX bugs."
)
return True

Expand Down
9 changes: 4 additions & 5 deletions pytensor/gradient.py
Original file line number Diff line number Diff line change
Expand Up @@ -624,8 +624,7 @@ def grad(

if cost is not None and isinstance(cost.type, NullType):
raise ValueError(
"Can't differentiate a NaN cost. "
f"Cost is NaN because {cost.type.why_null}"
f"Can't differentiate a NaN cost. Cost is NaN because {cost.type.why_null}"
)

if cost is not None and cost.type.ndim != 0:
Expand Down Expand Up @@ -2199,9 +2198,9 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
sequences=pytensor.tensor.arange(expr.shape[0]),
non_sequences=[expr, input],
)
assert (
not updates
), "Scan has returned a list of updates; this should not happen."
assert not updates, (
"Scan has returned a list of updates; this should not happen."
)
hessians.append(hess)
return as_list_or_tuple(using_list, using_tuple, hessians)

Expand Down
18 changes: 8 additions & 10 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,19 +427,17 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):

c = a + b # create a simple expression

f = pytensor.function(
[b], [c]
) # this works because a has a value associated with it already
# this works because a has a value associated with it already
f = pytensor.function([b], [c])

assert 4.0 == f(2.5) # bind 2.5 to an internal copy of b and evaluate an internal c
# bind 2.5 to an internal copy of b and evaluate an internal c
assert 4.0 == f(2.5)

pytensor.function(
[a], [c]
) # compilation error because b (required by c) is undefined
# compilation error because b (required by c) is undefined
pytensor.function([a], [c])

pytensor.function(
[a, b], [c]
) # compilation error because a is constant, it can't be an input
# compilation error because a is constant, it can't be an input
pytensor.function([a, b], [c])


The python variables ``a, b, c`` all refer to instances of type
Expand Down
3 changes: 1 addition & 2 deletions pytensor/graph/features.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,8 +391,7 @@ def __init__(self):
def on_attach(self, fgraph):
if hasattr(fgraph, "checkpoint") or hasattr(fgraph, "revert"):
raise AlreadyThere(
"History feature is already present or in"
" conflict with another plugin."
"History feature is already present or in conflict with another plugin."
)
self.history[fgraph] = []
# Don't call unpickle here, as ReplaceValidate.on_attach()
Expand Down
5 changes: 1 addition & 4 deletions pytensor/graph/rewriting/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -2975,10 +2975,7 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
raise ValueError("ops_to_check does not have the right type")

if not apply_nodes_to_check:
msg = (
"Provided op instances/classes are not in the graph or the "
"graph is empty"
)
msg = "Provided op instances/classes are not in the graph or the graph is empty"
if bug_print == "warn":
warnings.warn(msg)
elif bug_print == "raise":
Expand Down
2 changes: 1 addition & 1 deletion pytensor/graph/rewriting/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,7 @@ def query(self, *tags, position_cutoff: int | float | None = None, **kwtags):
return ret

def print_summary(self, stream=sys.stdout):
print(f"{self.__class__.__name__ } (id {id(self)})", file=stream)
print(f"{self.__class__.__name__} (id {id(self)})", file=stream)
positions = list(self.__position__.items())

def c(a, b):
Expand Down
15 changes: 7 additions & 8 deletions pytensor/link/c/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,12 +108,12 @@ def failure_code(sub, use_goto=True):
be careful to avoid executing incorrect code.

"""
id = sub["id"]
failure_var = sub["failure_var"]
if use_goto:
goto_statement = "goto __label_%(id)i;" % sub
goto_statement = f"goto __label_{id};"
else:
goto_statement = ""
id = sub["id"]
failure_var = sub["failure_var"]
return f"""{{
{failure_var} = {id};
if (!PyErr_Occurred()) {{
Expand Down Expand Up @@ -821,9 +821,9 @@ def code_gen(self):

behavior = op.c_code(node, name, isyms, osyms, sub)

assert isinstance(
behavior, str
), f"{node.op} didn't return a string for c_code"
assert isinstance(behavior, str), (
f"{node.op} didn't return a string for c_code"
)
# To help understand what is following. It help read the c code.
# This prevent different op that generate the same c code
# to be merged, I suppose this won't happen...
Expand Down Expand Up @@ -1753,8 +1753,7 @@ def __call__(self):
except Exception:
print( # noqa: T201
(
"ERROR retrieving error_storage."
"Was the error set in the c code?"
"ERROR retrieving error_storage. Was the error set in the c code?"
),
end=" ",
file=sys.stderr,
Expand Down
2 changes: 1 addition & 1 deletion pytensor/link/c/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -588,7 +588,7 @@ def c_code(self, node, name, inp, out, sub):
{define_macros}
{{
if ({self.func_name}({self.format_c_function_args(inp, out)}{params}) != 0) {{
{sub['fail']}
{sub["fail"]}
}}
}}
{undef_macros}
Expand Down
Loading