Skip to content

Commit 6a295b9

Browse files
committed
Remove unused stuff and type pytensor/utils.py
1 parent b083fb9 commit 6a295b9

File tree

4 files changed

+41
-58
lines changed

4 files changed

+41
-58
lines changed

pytensor/link/c/cmodule.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ def is_same_entry(entry_1, entry_2):
426426
return False
427427

428428

429-
def get_module_hash(src_code, key):
429+
def get_module_hash(src_code: str, key) -> str:
430430
"""
431431
Return a SHA256 hash that uniquely identifies a module.
432432
@@ -466,13 +466,13 @@ def get_module_hash(src_code, key):
466466
if isinstance(key_element, tuple):
467467
# This should be the C++ compilation command line parameters or the
468468
# libraries to link against.
469-
to_hash += list(key_element)
469+
to_hash += [str(e) for e in key_element]
470470
elif isinstance(key_element, str):
471471
if key_element.startswith("md5:") or key_element.startswith("hash:"):
472472
# This is actually a sha256 hash of the config options.
473473
# Currently, we still keep md5 to don't break old PyTensor.
474474
# We add 'hash:' so that when we change it in
475-
# the futur, it won't break this version of PyTensor.
475+
# the future, it won't break this version of PyTensor.
476476
break
477477
elif key_element.startswith("NPY_ABI_VERSION=0x") or key_element.startswith(
478478
"c_compiler_str="

pytensor/tensor/blas.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,7 @@
7575
7676
"""
7777

78+
import functools
7879
import logging
7980
import os
8081
import time
@@ -104,7 +105,6 @@
104105
from pytensor.tensor.math import add, mul, neg, sub
105106
from pytensor.tensor.shape import shape_padright, specify_broadcastable
106107
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
107-
from pytensor.utils import memoize
108108

109109

110110
_logger = logging.getLogger("pytensor.tensor.blas")
@@ -365,8 +365,10 @@ def ldflags(libs=True, flags=False, libs_dir=False, include_dir=False):
365365
)
366366

367367

368-
@memoize
369-
def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
368+
@functools.cache
369+
def _ldflags(
370+
ldflags_str: str, libs: bool, flags: bool, libs_dir: bool, include_dir: bool
371+
) -> list[str]:
370372
"""Extract list of compilation flags from a string.
371373
372374
Depending on the options, different type of flags will be kept.
@@ -422,7 +424,7 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
422424
t = t[1:-1]
423425

424426
try:
425-
t0, t1, t2 = t[0:3]
427+
t0, t1 = t[0], t[1]
426428
assert t0 == "-"
427429
except Exception:
428430
raise ValueError(f'invalid token "{t}" in ldflags_str: "{ldflags_str}"')
@@ -435,7 +437,6 @@ def _ldflags(ldflags_str, libs, flags, libs_dir, include_dir):
435437
" is not wanted.",
436438
t,
437439
)
438-
rval.append(t[2:])
439440
elif libs and t1 == "l": # example -lmkl
440441
rval.append(t[2:])
441442
elif flags and t1 not in ("L", "I", "l"): # example -openmp

pytensor/utils.py

Lines changed: 21 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@
66
import struct
77
import subprocess
88
import sys
9+
from collections.abc import Iterable, Sequence
910
from functools import partial
11+
from pathlib import Path
1012

1113

1214
__all__ = [
@@ -85,18 +87,6 @@ def add_excepthook(hook):
8587
sys.excepthook = __call_excepthooks
8688

8789

88-
def exc_message(e):
89-
"""
90-
In python 3.x, when an exception is reraised it saves original
91-
exception in its args, therefore in order to find the actual
92-
message, we need to unpack arguments recursively.
93-
"""
94-
msg = e.args[0]
95-
if isinstance(msg, Exception):
96-
return exc_message(msg)
97-
return msg
98-
99-
10090
def get_unbound_function(unbound):
10191
# Op.make_thunk isn't bound, so don't have a __func__ attr.
10292
# But bound method, have a __func__ method that point to the
@@ -106,8 +96,9 @@ def get_unbound_function(unbound):
10696
return unbound
10797

10898

109-
def maybe_add_to_os_environ_pathlist(var, newpath):
110-
"""Unfortunately, Conda offers to make itself the default Python
99+
def maybe_add_to_os_environ_pathlist(var: str, newpath: Path | str) -> None:
100+
"""
101+
Unfortunately, Conda offers to make itself the default Python
111102
and those who use it that way will probably not activate envs
112103
correctly meaning e.g. mingw-w64 g++ may not be on their PATH.
113104
@@ -118,18 +109,18 @@ def maybe_add_to_os_environ_pathlist(var, newpath):
118109
The reason we check first is because Windows environment vars
119110
are limited to 8191 characters and it is easy to hit that.
120111
121-
`var` will typically be 'PATH'."""
122-
123-
import os
112+
`var` will typically be 'PATH'.
113+
"""
114+
if not Path(newpath).is_absolute():
115+
return
124116

125-
if os.path.isabs(newpath):
126-
try:
127-
oldpaths = os.environ[var].split(os.pathsep)
128-
if newpath not in oldpaths:
129-
newpaths = os.pathsep.join([newpath, *oldpaths])
130-
os.environ[var] = newpaths
131-
except Exception:
132-
pass
117+
try:
118+
oldpaths = os.environ[var].split(os.pathsep)
119+
if str(newpath) not in oldpaths:
120+
newpaths = os.pathsep.join([str(newpath), *oldpaths])
121+
os.environ[var] = newpaths
122+
except Exception:
123+
pass
133124

134125

135126
def subprocess_Popen(command, **params):
@@ -210,7 +201,7 @@ def output_subprocess_Popen(command, **params):
210201
return (*out, p.returncode)
211202

212203

213-
def hash_from_code(msg):
204+
def hash_from_code(msg: str | bytes) -> str:
214205
"""Return the SHA256 hash of a string or bytes."""
215206
# hashlib.sha256() requires an object that supports buffer interface,
216207
# but Python 3 (unicode) strings don't.
@@ -221,27 +212,7 @@ def hash_from_code(msg):
221212
return "m" + hashlib.sha256(msg).hexdigest()
222213

223214

224-
def memoize(f):
225-
"""
226-
Cache the return value for each tuple of arguments (which must be hashable).
227-
228-
"""
229-
cache = {}
230-
231-
def rval(*args, **kwargs):
232-
kwtup = tuple(kwargs.items())
233-
key = (args, kwtup)
234-
if key not in cache:
235-
val = f(*args, **kwargs)
236-
cache[key] = val
237-
else:
238-
val = cache[key]
239-
return val
240-
241-
return rval
242-
243-
244-
def uniq(seq):
215+
def uniq(seq: Sequence) -> list:
245216
"""
246217
Do not use set, this must always return the same value at the same index.
247218
If we just exchange other values, but keep the same pattern of duplication,
@@ -253,11 +224,12 @@ def uniq(seq):
253224
return [x for i, x in enumerate(seq) if seq.index(x) == i]
254225

255226

256-
def difference(seq1, seq2):
227+
def difference(seq1: Iterable, seq2: Iterable):
257228
r"""
258229
Returns all elements in seq1 which are not in seq2: i.e ``seq1\seq2``.
259230
260231
"""
232+
seq2 = list(seq2)
261233
try:
262234
# try to use O(const * len(seq1)) algo
263235
if len(seq2) < 4: # I'm guessing this threshold -JB
@@ -285,7 +257,7 @@ def from_return_values(values):
285257
return [values]
286258

287259

288-
def flatten(a):
260+
def flatten(a) -> list:
289261
"""
290262
Recursively flatten tuple, list and set in a list.
291263

tests/compile/function/test_types.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131
scalars,
3232
vector,
3333
)
34-
from pytensor.utils import exc_message
3534

3635

3736
def PatternOptimizer(p1, p2, ign=True):
@@ -1182,6 +1181,17 @@ def pers_save(obj):
11821181
def pers_load(id):
11831182
return saves[id]
11841183

1184+
def exc_message(e):
1185+
"""
1186+
In Python 3, when an exception is reraised it saves the original
1187+
exception in its args, therefore in order to find the actual
1188+
message, we need to unpack arguments recursively.
1189+
"""
1190+
msg = e.args[0]
1191+
if isinstance(msg, Exception):
1192+
return exc_message(msg)
1193+
return msg
1194+
11851195
b = np.random.random((5, 4))
11861196

11871197
x = matrix()

0 commit comments

Comments
 (0)