Skip to content

Commit 4ad9d52

Browse files
committed
Remove os.path and add type hints to configparser.py
1 parent d62260c commit 4ad9d52

File tree

4 files changed

+156
-28
lines changed

4 files changed

+156
-28
lines changed

pytensor/configparser.py

Lines changed: 140 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from functools import wraps
1515
from io import StringIO
16+
from pathlib import Path
1617

1718
from pytensor.utils import hash_from_code
1819

@@ -22,7 +23,7 @@
2223

2324
class PyTensorConfigWarning(Warning):
2425
@classmethod
25-
def warn(cls, message, stacklevel=0):
26+
def warn(cls, message: str, stacklevel: int = 0):
2627
warnings.warn(message, cls, stacklevel=stacklevel + 3)
2728

2829

@@ -68,7 +69,123 @@ def __exit__(self, *args):
6869
class PyTensorConfigParser:
6970
"""Object that holds configuration settings."""
7071

71-
def __init__(self, flags_dict: dict, pytensor_cfg, pytensor_raw_cfg):
72+
# add_basic_configvars
73+
floatX: str
74+
warn_float64: str
75+
pickle_test_value: bool
76+
cast_policy: str
77+
deterministic: str
78+
device: str
79+
force_device: bool
80+
conv__assert_shape: bool
81+
print_global_stats: bool
82+
assert_no_cpu_op: str
83+
unpickle_function: bool
84+
# add_compile_configvars
85+
mode: str
86+
cxx: str
87+
linker: str
88+
allow_gc: bool
89+
optimizer: str
90+
optimizer_verbose: bool
91+
on_opt_error: str
92+
nocleanup: bool
93+
on_unused_import: str
94+
gcc__cxxflags: str
95+
cmodule__warn_no_version: bool
96+
cmodule__remove_gxx_opt: bool
97+
cmodule__compilation_warning: bool
98+
cmodule__preload_cache: bool
99+
cmodule__age_thresh_use: int
100+
cmodule__debug: bool
101+
compile__wait: int
102+
compile__timeout: int
103+
ctc__root: str
104+
# add_tensor_configvars
105+
tensor__cmp_sloppy: int
106+
lib__amblibm: bool
107+
tensor__insert_inplace_optimizer_validate_nb: int
108+
# add_traceback_configvars
109+
traceback__limit: int
110+
traceback__compile_limit: int
111+
# add_experimental_configvars
112+
# add_error_and_warning_configvars
113+
warn__ignore_bug_before: int
114+
exception_verbosity: str
115+
# add_testvalue_and_checking_configvars
116+
print_test_value: bool
117+
compute_test_value: str
118+
compute_test_value_opt: str
119+
check_input: bool
120+
NanGuardMode__nan_is_error: bool
121+
NanGuardMode__inf_is_error: bool
122+
NanGuardMode__big_is_error: bool
123+
NanGuardMode__action: str
124+
DebugMode__patience: int
125+
DebugMode__check_c: bool
126+
DebugMode__check_py: bool
127+
DebugMode__check_finite: bool
128+
DebugMode__check_strides: int
129+
DebugMode__warn_input_not_reused: bool
130+
DebugMode__check_preallocated_output: str
131+
DebugMode__check_preallocated_output_ndim: int
132+
profiling__time_thunks: bool
133+
profiling__n_apply: int
134+
profiling__n_ops: int
135+
profiling__output_line_width: int
136+
profiling__min_memory_size: int
137+
profiling__min_peak_memory: bool
138+
profiling__destination: str
139+
profiling__debugprint: bool
140+
profiling__ignore_first_call: bool
141+
on_shape_error: str
142+
# add_multiprocessing_configvars
143+
openmp: bool
144+
openmp_elemwise_minsize: int
145+
# add_optimizer_configvars
146+
optimizer_excluding: str
147+
optimizer_including: str
148+
optimizer_requiring: str
149+
optdb__position_cutoff: float
150+
optdb__max_use_ratio: float
151+
cycle_detection: str
152+
check_stack_trace: str
153+
metaopt__verbose: int
154+
metaopt__optimizer_excluding: str
155+
metaopt__optimizer_including: str
156+
# add_vm_configvars
157+
profile: bool
158+
profile_optimizer: bool
159+
profile_memory: bool
160+
vm__lazy: bool | None
161+
# add_deprecated_configvars
162+
unittests__rseed: str
163+
warn__round: bool
164+
# add_scan_configvars
165+
scan__allow_gc: bool
166+
scan__allow_output_prealloc: bool
167+
# add_numba_configvars
168+
numba__vectorize_target: str
169+
numba__fastmath: bool
170+
numba__cache: bool
171+
# add_caching_dir_configvars
172+
compiledir_format: str
173+
base_compiledir: Path
174+
compiledir: Path
175+
# add_blas_configvars
176+
blas__ldflags: str
177+
blas__check_openmp: bool
178+
# add CUDA (?)
179+
cuda__root: Path | None
180+
dnn__base_path: Path | None
181+
dnn__library_path: Path | None
182+
183+
def __init__(
184+
self,
185+
flags_dict: dict,
186+
pytensor_cfg: ConfigParser,
187+
pytensor_raw_cfg: RawConfigParser,
188+
):
72189
self._flags_dict = flags_dict
73190
self._pytensor_cfg = pytensor_cfg
74191
self._pytensor_raw_cfg = pytensor_raw_cfg
@@ -80,7 +197,7 @@ def __str__(self, print_doc=True):
80197
self.config_print(buf=sio, print_doc=print_doc)
81198
return sio.getvalue()
82199

83-
def config_print(self, buf, print_doc=True):
200+
def config_print(self, buf, print_doc: bool = True):
84201
for cv in self._config_var_dict.values():
85202
print(cv, file=buf)
86203
if print_doc:
@@ -108,7 +225,9 @@ def get_config_hash(self):
108225
)
109226
)
110227

111-
def add(self, name, doc, configparam, in_c_key=True):
228+
def add(
229+
self, name: str, doc: str, configparam: "ConfigParam", in_c_key: bool = True
230+
):
112231
"""Add a new variable to PyTensorConfigParser.
113232
114233
This method performs some of the work of initializing `ConfigParam` instances.
@@ -168,7 +287,7 @@ def add(self, name, doc, configparam, in_c_key=True):
168287
# the ConfigParam implements __get__/__set__, enabling us to create a property:
169288
setattr(self.__class__, name, configparam)
170289

171-
def fetch_val_for_key(self, key, delete_key=False):
290+
def fetch_val_for_key(self, key, delete_key: bool = False):
172291
"""Return the overriding config value for a key.
173292
A successful search returns a string value.
174293
An unsuccessful search raises a KeyError
@@ -260,9 +379,9 @@ def __init__(
260379
self._mutable = mutable
261380
self.is_default = True
262381
# set by PyTensorConfigParser.add:
263-
self.name = None
264-
self.doc = None
265-
self.in_c_key = None
382+
self.name: str = "unnamed"
383+
self.doc: str = "undocumented"
384+
self.in_c_key: bool
266385

267386
# Note that we do not call `self.filter` on the default value: this
268387
# will be done automatically in PyTensorConfigParser.add, potentially with a
@@ -288,7 +407,7 @@ def apply(self, value):
288407
return self._apply(value)
289408
return value
290409

291-
def validate(self, value) -> bool | None:
410+
def validate(self, value) -> bool:
292411
"""Validates that a parameter values falls into a supported set or range.
293412
294413
Raises
@@ -336,7 +455,7 @@ def __set__(self, cls, val):
336455

337456
class EnumStr(ConfigParam):
338457
def __init__(
339-
self, default: str, options: Sequence[str], validate=None, mutable=True
458+
self, default: str, options: Sequence[str], validate=None, mutable: bool = True
340459
):
341460
"""Creates a str-based parameter that takes a predefined set of options.
342461
@@ -400,7 +519,7 @@ class BoolParam(TypedParam):
400519
True, 1, "true", "True", "1"
401520
"""
402521

403-
def __init__(self, default, validate=None, mutable=True):
522+
def __init__(self, default, validate=None, mutable: bool = True):
404523
super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)
405524

406525
def _apply(self, value):
@@ -454,7 +573,9 @@ def _apply(self, val):
454573
return val
455574

456575

457-
def parse_config_string(config_string, issue_warnings=True):
576+
def parse_config_string(
577+
config_string: str, issue_warnings: bool = True
578+
) -> dict[str, str]:
458579
"""
459580
Parses a config string (comma-separated key=value components) into a dict.
460581
"""
@@ -480,7 +601,7 @@ def parse_config_string(config_string, issue_warnings=True):
480601
return config_dict
481602

482603

483-
def config_files_from_pytensorrc():
604+
def config_files_from_pytensorrc() -> list[Path]:
484605
"""
485606
PYTENSORRC can contain a colon-delimited list of config files, like
486607
@@ -489,17 +610,17 @@ def config_files_from_pytensorrc():
489610
In that case, definitions in files on the right (here, ``~/.pytensorrc``)
490611
have precedence over those in files on the left.
491612
"""
492-
rval = [
493-
os.path.expanduser(s)
613+
paths = [
614+
Path(s).expanduser()
494615
for s in os.getenv("PYTENSORRC", "~/.pytensorrc").split(os.pathsep)
495616
]
496617
if os.getenv("PYTENSORRC") is None and sys.platform == "win32":
497618
# to don't need to change the filename and make it open easily
498-
rval.append(os.path.expanduser("~/.pytensorrc.txt"))
499-
return rval
619+
paths.append(Path("~/.pytensorrc.txt").expanduser())
620+
return paths
500621

501622

502-
def _create_default_config():
623+
def _create_default_config() -> PyTensorConfigParser:
503624
# The PYTENSOR_FLAGS environment variable should be a list of comma-separated
504625
# [section__]option=value entries. If the section part is omitted, there should
505626
# be only one section that contains the given option.
@@ -509,7 +630,7 @@ def _create_default_config():
509630
config_files = config_files_from_pytensorrc()
510631
pytensor_cfg = ConfigParser(
511632
{
512-
"USER": os.getenv("USER", os.path.split(os.path.expanduser("~"))[-1]),
633+
"USER": os.getenv("USER", Path("~").expanduser().name),
513634
"LSCRATCH": os.getenv("LSCRATCH", ""),
514635
"TMPDIR": os.getenv("TMPDIR", ""),
515636
"TEMP": os.getenv("TEMP", ""),

pytensor/link/c/cmodule.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from collections.abc import Callable
2424
from contextlib import AbstractContextManager, nullcontext
2525
from io import BytesIO, StringIO
26+
from pathlib import Path
2627
from typing import TYPE_CHECKING, Protocol, cast
2728

2829
import numpy as np
@@ -688,7 +689,7 @@ class ModuleCache:
688689
689690
"""
690691

691-
dirname: str = ""
692+
dirname: Path
692693
"""
693694
The working directory that is managed by this interface.
694695
@@ -725,8 +726,13 @@ class ModuleCache:
725726
726727
"""
727728

728-
def __init__(self, dirname, check_for_broken_eq=True, do_refresh=True):
729-
self.dirname = dirname
729+
def __init__(
730+
self,
731+
dirname: Path | str,
732+
check_for_broken_eq: bool = True,
733+
do_refresh: bool = True,
734+
):
735+
self.dirname = Path(dirname)
730736
self.module_from_name = dict(self.module_from_name)
731737
self.entry_from_key = dict(self.entry_from_key)
732738
self.module_hash_to_key_data = dict(self.module_hash_to_key_data)
@@ -1637,12 +1643,12 @@ def _rmtree(
16371643
_module_cache: ModuleCache | None = None
16381644

16391645

1640-
def get_module_cache(dirname: str, init_args=None) -> ModuleCache:
1646+
def get_module_cache(dirname: Path | str, init_args=None) -> ModuleCache:
16411647
"""Create a new module_cache.
16421648
16431649
Parameters
16441650
----------
1645-
dirname
1651+
dirname : Path | str
16461652
The name of the directory used by the cache.
16471653
init_args
16481654
Keyword arguments passed to the `ModuleCache` constructor.

pytensor/link/c/cutils.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import errno
22
import os
33
import sys
4+
from pathlib import Path
45

56
from pytensor.compile.compilelock import lock_ctx
67
from pytensor.configdefaults import config
@@ -87,8 +88,8 @@ def compile_cutils():
8788
# for the same reason. Note that these 5 lines may seem redundant (they are
8889
# repeated in compile_str()) but if another cutils_ext does exist then it
8990
# will be imported and compile_str won't get called at all.
90-
sys.path.insert(0, config.compiledir)
91-
location = os.path.join(config.compiledir, "cutils_ext")
91+
sys.path.insert(0, str(config.compiledir))
92+
location = os.path.join(str(config.compiledir), "cutils_ext")
9293
if not os.path.exists(location):
9394
try:
9495
os.mkdir(location)
@@ -115,5 +116,5 @@ def compile_cutils():
115116
compile_cutils()
116117
from cutils_ext.cutils_ext import * # noqa
117118
finally:
118-
if sys.path[0] == config.compiledir:
119+
if config.compiledir.resolve() == Path(sys.path[0]).resolve():
119120
del sys.path[0]

tests/test_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def test_deviceparam(self):
138138
cp._apply("gpu123")
139139
with pytest.raises(ValueError, match='Valid options start with one of "cpu".'):
140140
cp._apply("notadevice")
141-
assert str(cp) == "None (cpu)"
141+
assert str(cp) == "unnamed (cpu)"
142142

143143

144144
def test_config_context():

0 commit comments

Comments
 (0)