8 changes: 4 additions & 4 deletions .pre-commit-config.yaml
@@ -9,7 +9,7 @@ exclude: |
   )$
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
-    rev: v5.0.0
+    rev: v6.0.0
     hooks:
       - id: debug-statements
         exclude: |
@@ -25,11 +25,11 @@ repos:
     rev: v1.0.0
     hooks:
       - id: sphinx-lint
-        args: ["."]
+        args: ["-i", ".pixi", "."]
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.7.3
+    rev: v0.14.0
     hooks:
-      - id: ruff
+      - id: ruff-check
         types_or: [python, pyi, jupyter]
         args: ["--fix", "--output-format=full"]
       - id: ruff-format
9 changes: 3 additions & 6 deletions pytensor/bin/pytensor_cache.py
@@ -1,6 +1,7 @@
 import logging
 import os
 import sys
+from pathlib import Path


 if sys.platform == "win32":
@@ -24,7 +25,7 @@

 def print_help(exit_status):
     if exit_status:
-        print(f"command \"{' '.join(sys.argv)}\" not recognized")
+        print(f'command "{" ".join(sys.argv)}" not recognized')
     print('Type "pytensor-cache" to print the cache location')
     print('Type "pytensor-cache help" to print this help')
     print('Type "pytensor-cache clear" to erase the cache')
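A quick aside on this hunk: flipping the outer quote style lets the f-string drop the backslash escapes without changing the rendered message. A minimal check, runnable in any Python 3 interpreter:

import sys

old = f"command \"{' '.join(sys.argv)}\" not recognized"
new = f'command "{" ".join(sys.argv)}" not recognized'
assert old == new  # both spellings render the identical message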
@@ -65,11 +66,7 @@ def main():
     # Print a warning if some cached modules were not removed, so that the
     # user knows he should manually delete them, or call
     # pytensor-cache purge, # to properly clear the cache.
-    items = [
-        item
-        for item in sorted(os.listdir(cache.dirname))
-        if item.startswith("tmp")
-    ]
+    items = list(Path(cache.dirname).glob("tmp*"))
Review discussion on this change:

Member: is this sorted by default?

Member: I guess it doesn't matter though

Member (Author): It is not, but yes, it doesn't matter here.

     if items:
         _logger.warning(
             "There remain elements in the cache dir that you may "
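For readers wondering about the question above: `Path.glob` makes no ordering guarantee, while the replaced expression sorted explicitly. A minimal sketch (the directory name is hypothetical):

from pathlib import Path

# Path.glob yields entries in arbitrary, filesystem-dependent order;
# sort explicitly if deterministic ordering ever becomes relevant.
items = sorted(Path("/tmp/pytensor-cache").glob("tmp*"))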
2 changes: 1 addition & 1 deletion pytensor/compile/builders.py
@@ -122,7 +122,7 @@ def construct_nominal_fgraph(
     (
         local_inputs,
         local_outputs,
-        (clone_d, update_d, update_expr, new_shared_inputs),
+        (_clone_d, update_d, update_expr, new_shared_inputs),
     ) = new

     assert len(local_inputs) == len(inputs) + len(implicit_shared_inputs)
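The `_clone_d` rename here, like the later `_update_expr`, `_link`, `_out_idx`, `_rewritten_nodes`, and `_fns` renames in this PR, follows the linter convention that a leading underscore marks an intentionally unused binding. A toy illustration with hypothetical names:

# Unpacking a tuple but using only some elements: the underscore
# prefix signals to ruff/pyflakes that the value is deliberately ignored.
first, _middle, last = (1, 2, 3)
print(first + last)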
2 changes: 1 addition & 1 deletion pytensor/compile/function/__init__.py
@@ -12,7 +12,7 @@
 from pytensor.graph import Variable


-__all__ = ["types", "pfunc"]
+__all__ = ["pfunc", "types"]

 __docformat__ = "restructuredtext en"
 _logger = logging.getLogger("pytensor.compile.function")
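The reordering above is behavior-neutral: `__all__` only lists the names exported by `from ... import *`, so sorting it is purely a style fix (most likely ruff's unsorted-`__all__` rule). For instance:

# Order does not affect star-import behavior; only membership matters.
assert set(["pfunc", "types"]) == set(["types", "pfunc"])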
5 changes: 2 additions & 3 deletions pytensor/compile/function/pfunc.py
@@ -328,8 +328,7 @@ def clone_inputs(i):
         cloned_outputs = []  # TODO: get Function.__call__ to return None
     else:
         raise TypeError(
-            "output must be an PyTensor Variable or Out "
-            "instance (or list of them)",
+            "output must be an PyTensor Variable or Out instance (or list of them)",
             outputs,
         )

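Merging the two string literals is safe because adjacent literals concatenate at compile time; the single literal just removes a spot where a missing space at the seam could corrupt the message. A minimal check:

msg = "output must be an PyTensor Variable or Out " "instance (or list of them)"
assert msg == "output must be an PyTensor Variable or Out instance (or list of them)"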
@@ -592,7 +591,7 @@ def construct_pfunc_ins_and_outs(
         clone_inner_graphs=True,
     )
     input_variables, cloned_extended_outputs, other_stuff = output_vars
-    clone_d, update_d, update_expr, shared_inputs = other_stuff
+    clone_d, update_d, _update_expr, shared_inputs = other_stuff

     # Recover only the clones of the original outputs
     if outputs is None:
2 changes: 1 addition & 1 deletion pytensor/compile/function/types.py
@@ -215,7 +215,7 @@ def add_supervisor_to_fgraph(
                 input
                 for spec, input in zip(input_specs, fgraph.inputs, strict=True)
                 if not (
-                    spec.mutable or has_destroy_handler and fgraph.has_destroyers([input])
+                    spec.mutable or (has_destroy_handler and fgraph.has_destroyers([input]))
                 )
             )
         )
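The added parentheses in this hunk are a readability fix only: `and` already binds tighter than `or`, so both spellings evaluate identically. A quick demonstration:

# `a or b and c` parses as `a or (b and c)`; the parentheses make
# the existing precedence explicit without changing the result.
for a in (False, True):
    for b in (False, True):
        for c in (False, True):
            assert (a or b and c) == (a or (b and c))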
2 changes: 1 addition & 1 deletion pytensor/compile/mode.py
@@ -407,7 +407,7 @@ def register(self, *optimizations):
         optimizations.
         """

-        link, opt = self.get_linker_optimizer(
+        _link, opt = self.get_linker_optimizer(
             self.provided_linker, self.provided_optimizer
         )
         return self.clone(optimizer=opt.register(*optimizations))
3 changes: 1 addition & 2 deletions pytensor/compile/monitormode.py
@@ -45,8 +45,7 @@ def __init__(
             optimizer = config.optimizer
         if linker is not None and not isinstance(linker.mode, MonitorMode):
             raise Exception(
-                "MonitorMode can only use its own linker! You "
-                "should not provide one.",
+                "MonitorMode can only use its own linker! You should not provide one.",
                 linker,
             )

63 changes: 26 additions & 37 deletions pytensor/compile/profiling.py
@@ -514,14 +514,12 @@ def summary_class(self, file=sys.stderr, N=None):
             # While this carries over less information, it is arranged such
             # that it is way more readable than the previous output of the
             # profiler
+            nb_classes = max(0, len(otimes) - N)
+            percent_total = sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:])
+            time_total = sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])
             print(
-                " ... (remaining %i Classes account for %6.2f%%(%.2fs) of "
-                "the runtime)"
-                % (
-                    max(0, len(otimes) - N),
-                    sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]),
-                    sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:]),
-                ),
+                f" ... (remaining {nb_classes} Classes account for "
+                f"{percent_total:6.2f}%%({time_total:.2f}s) of the runtime)",
                 file=file,
             )
             print("", file=file)
@@ -607,14 +605,12 @@ def summary_ops(self, file=sys.stderr, N=None):
             # While this carries over less information, it is arranged such
             # that it is way more readable than the previous output of the
             # profiler
+            nb_ops = max(0, len(otimes) - N)
+            percent_total = sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:])
+            time_total = sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:])
             print(
-                " ... (remaining %i Ops account for %6.2f%%(%.2fs) of "
-                "the runtime)"
-                % (
-                    max(0, len(otimes) - N),
-                    sum(f for f, t, a, ci, nb_call, nb_op in otimes[N:]),
-                    sum(t for f, t, a, ci, nb_call, nb_op in otimes[N:]),
-                ),
+                f" ... (remaining {nb_ops} Ops account for "
+                f"{percent_total:6.2f}%%({time_total:.2f}s) of the runtime)",
                 file=file,
             )
             print("", file=file)
@@ -935,18 +931,14 @@ def count_running_memory(order, fgraph, nodes_mem, ignore_dmap=False):
                     if dmap and idx2 in dmap:
                         vidx = dmap[idx2]
                         assert len(vidx) == 1, (
-                            "Here we only support the "
-                            "possibility to destroy one "
-                            "input"
+                            "Here we only support the possibility to destroy one input"
                         )
                         ins = node.inputs[vidx[0]]
                     if vmap and idx2 in vmap:
                         assert ins is None
                         vidx = vmap[idx2]
                         assert len(vidx) == 1, (
-                            "Here we only support the "
-                            "possibility to view one "
-                            "input"
+                            "Here we only support the possibility to view one input"
                         )
                         ins = node.inputs[vidx[0]]
                     if ins is not None:
@@ -1093,9 +1085,7 @@ def min_memory_generator(executable_nodes, viewed_by, view_of):
                 assert ins is None
                 vidx = vmap[idx]
                 assert len(vidx) == 1, (
-                    "Here we only support "
-                    "the possibility to "
-                    "view one input"
+                    "Here we only support the possibility to view one input"
                 )
                 ins = node.inputs[vidx[0]]
                 if ins is not None:
@@ -1304,22 +1294,22 @@ def print_stats(stats1, stats2):

         print(
             (
-                f" CPU: {int(round(new_max_running_max_memory_size[1] / 1024.0))}KB "
-                f"({int(round(max_running_max_memory_size[1] / 1024.0))}KB)"
+                f" CPU: {round(new_max_running_max_memory_size[1] / 1024.0)}KB "
+                f"({round(max_running_max_memory_size[1] / 1024.0)}KB)"
             ),
             file=file,
         )
         print(
             (
-                f" GPU: {int(round(new_max_running_max_memory_size[2] / 1024.0))}KB "
-                f"({int(round(max_running_max_memory_size[2] / 1024.0))}KB)"
+                f" GPU: {round(new_max_running_max_memory_size[2] / 1024.0)}KB "
+                f"({round(max_running_max_memory_size[2] / 1024.0)}KB)"
             ),
             file=file,
         )
         print(
             (
-                f" CPU + GPU: {int(round(new_max_running_max_memory_size[0] / 1024.0))}KB "
-                f"({int(round(max_running_max_memory_size[0] / 1024.0))}KB)"
+                f" CPU + GPU: {round(new_max_running_max_memory_size[0] / 1024.0)}KB "
+                f"({round(max_running_max_memory_size[0] / 1024.0)}KB)"
             ),
             file=file,
         )
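Dropping the `int(...)` wrappers throughout these hunks is safe because, in Python 3, `round` called on a float without `ndigits` already returns an `int`. A quick check:

x = 2047.6 / 1024.0
assert isinstance(round(x), int)
assert round(x) == int(round(x))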
@@ -1333,30 +1323,30 @@ def print_stats(stats1, stats2):
         )
         print_stats(stats[1], stats[3])

-        (max_node_memory_size, _, _, _) = stats[0]
+        (_max_node_memory_size, _, _, _) = stats[0]
         (new_max_node_memory_size, _, _, _) = stats[2]
         print(
             " Max peak memory if allow_gc=False (linker don't make a difference)",
             file=file,
         )
         print(
-            f" CPU: {int(round(new_max_node_memory_size[1] / 1024.0))}KB",
+            f" CPU: {round(new_max_node_memory_size[1] / 1024.0)}KB",
             file=file,
         )
         print(
-            f" GPU: {int(round(new_max_node_memory_size[2] / 1024.0))}KB",
+            f" GPU: {round(new_max_node_memory_size[2] / 1024.0)}KB",
             file=file,
         )
         print(
-            f" CPU + GPU: {int(round(new_max_node_memory_size[0] / 1024.0))}KB",
+            f" CPU + GPU: {round(new_max_node_memory_size[0] / 1024.0)}KB",
             file=file,
         )
         print("---", file=file)

         if min_max_peak:
             print(
                 " Minimum peak from all valid apply node order is "
-                f"{int(round(min_max_peak / 1024.0))}KB(took {min_peak_time:3f}s to compute)",
+                f"{round(min_max_peak / 1024.0)}KB(took {min_peak_time:3f}s to compute)",
                 file=file,
             )

@@ -1405,7 +1395,7 @@ def print_stats(stats1, stats2):
         print(
             (
                 f" ... (remaining {max(0, len(node_mem) - N)} Apply account for "
-                f"{sum_remaining:4d}B/{size_sum_dense :d}B ({p}) of the"
+                f"{sum_remaining:4d}B/{size_sum_dense:d}B ({p}) of the"
                 " Apply with dense outputs sizes)"
             ),
             file=file,
@@ -1545,8 +1535,7 @@ def amdlibm_speed_up(op):
         return True
     elif s_op.__class__ not in scalar_op_amdlibm_no_speed_up:
         print(
-            "We don't know if amdlibm will accelerate "
-            "this scalar op.",
+            "We don't know if amdlibm will accelerate this scalar op.",
             s_op,
             file=file,
         )
3 changes: 1 addition & 2 deletions pytensor/configdefaults.py
@@ -67,8 +67,7 @@ def _warn_cxx(val):
     """We only support clang++ as otherwise we hit strange g++/OSX bugs."""
     if sys.platform == "darwin" and val and "clang++" not in val:
         _logger.warning(
-            "Only clang++ is supported. With g++,"
-            " we end up with strange g++/OSX bugs."
+            "Only clang++ is supported. With g++, we end up with strange g++/OSX bugs."
         )
     return True

9 changes: 4 additions & 5 deletions pytensor/gradient.py
@@ -624,8 +624,7 @@ def grad(

     if cost is not None and isinstance(cost.type, NullType):
         raise ValueError(
-            "Can't differentiate a NaN cost. "
-            f"Cost is NaN because {cost.type.why_null}"
+            f"Can't differentiate a NaN cost. Cost is NaN because {cost.type.why_null}"
         )

     if cost is not None and cost.type.ndim != 0:
@@ -2199,9 +2198,9 @@ def hessian(cost, wrt, consider_constant=None, disconnected_inputs="raise"):
             sequences=pytensor.tensor.arange(expr.shape[0]),
             non_sequences=[expr, input],
         )
-        assert (
-            not updates
-        ), "Scan has returned a list of updates; this should not happen."
+        assert not updates, (
+            "Scan has returned a list of updates; this should not happen."
+        )
         hessians.append(hess)
     return as_list_or_tuple(using_list, using_tuple, hessians)

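The reflowed `assert` keeps the message in its own parenthesized continuation rather than parenthesizing the condition. That formatting also sidesteps the classic pitfall where condition and message get wrapped together in one tuple, which is always truthy. For example:

updates = []
# Correct: the message follows the comma, outside the condition.
assert not updates, "Scan has returned a list of updates; this should not happen."
# Buggy variant (never fires): a non-empty tuple is always truthy.
# assert (not updates, "Scan has returned a list of updates")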
18 changes: 8 additions & 10 deletions pytensor/graph/basic.py
@@ -427,19 +427,17 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):

        c = a + b  # create a simple expression

-        f = pytensor.function(
-            [b], [c]
-        )  # this works because a has a value associated with it already
+        # this works because a has a value associated with it already
+        f = pytensor.function([b], [c])

-        assert 4.0 == f(2.5)  # bind 2.5 to an internal copy of b and evaluate an internal c
+        # bind 2.5 to an internal copy of b and evaluate an internal c
+        assert 4.0 == f(2.5)

-        pytensor.function(
-            [a], [c]
-        )  # compilation error because b (required by c) is undefined
+        # compilation error because b (required by c) is undefined
+        pytensor.function([a], [c])

-        pytensor.function(
-            [a, b], [c]
-        )  # compilation error because a is constant, it can't be an input
+        # compilation error because a is constant, it can't be an input
+        pytensor.function([a, b], [c])


     The python variables ``a, b, c`` all refer to instances of type
3 changes: 1 addition & 2 deletions pytensor/graph/features.py
@@ -391,8 +391,7 @@ def __init__(self):
     def on_attach(self, fgraph):
         if hasattr(fgraph, "checkpoint") or hasattr(fgraph, "revert"):
             raise AlreadyThere(
-                "History feature is already present or in"
-                " conflict with another plugin."
+                "History feature is already present or in conflict with another plugin."
             )
         self.history[fgraph] = []
         # Don't call unpickle here, as ReplaceValidate.on_attach()
2 changes: 1 addition & 1 deletion pytensor/graph/fg.py
@@ -559,7 +559,7 @@ def remove_node(self, node: Apply, reason: str | None = None):

         out_clients = clients.get(out, ())
         while out_clients:
-            out_client, out_idx = out_clients.pop()
+            out_client, _out_idx = out_clients.pop()

             if isinstance(out_client.op, Output):
                 self.remove_output(out_client.op.idx, remove_client=False)
7 changes: 2 additions & 5 deletions pytensor/graph/rewriting/basic.py
@@ -2827,7 +2827,7 @@ def local_recursive_function(
     else:
         out_index = 0

-    final_outs, rewritten_nodes = local_recursive_function(rewrites, out, {}, 0)
+    final_outs, _rewritten_nodes = local_recursive_function(rewrites, out, {}, 0)
     return final_outs[out_index]

@@ -2975,10 +2975,7 @@ def check_stack_trace(f_or_fgraph, ops_to_check="last", bug_print="raise"):
         raise ValueError("ops_to_check does not have the right type")

     if not apply_nodes_to_check:
-        msg = (
-            "Provided op instances/classes are not in the graph or the "
-            "graph is empty"
-        )
+        msg = "Provided op instances/classes are not in the graph or the graph is empty"
         if bug_print == "warn":
             warnings.warn(msg)
         elif bug_print == "raise":
2 changes: 1 addition & 1 deletion pytensor/graph/rewriting/db.py
@@ -464,7 +464,7 @@ def query(self, *tags, position_cutoff: int | float | None = None, **kwtags):
         return ret

     def print_summary(self, stream=sys.stdout):
-        print(f"{self.__class__.__name__ } (id {id(self)})", file=stream)
+        print(f"{self.__class__.__name__} (id {id(self)})", file=stream)
         positions = list(self.__position__.items())

         def c(a, b):
2 changes: 1 addition & 1 deletion pytensor/link/basic.py
@@ -509,7 +509,7 @@ def make_thunk(self, **kwargs):
         kwargs.pop("input_storage", None)
         make_all += [x.make_all(**kwargs) for x in self.linkers[1:]]

-        fns, input_lists, output_lists, thunk_lists, order_lists = zip(
+        _fns, input_lists, output_lists, thunk_lists, order_lists = zip(
             *make_all, strict=True
         )
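For context on the `zip(*make_all, strict=True)` idiom this hunk touches: unpacking with `zip` transposes a list of equal-length tuples, and `strict=True` (Python 3.10+) raises `ValueError` on a length mismatch instead of silently truncating. A small sketch with hypothetical data:

rows = [("f1", "i1", "o1"), ("f2", "i2", "o2")]
fns, inputs, outputs = zip(*rows, strict=True)
assert fns == ("f1", "f2")  # the first column of every row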