Skip to content

Commit 910e0cf

Browse files
committed
Remove os.path and add type hints to configparser.py
1 parent 8090d8a commit 910e0cf

File tree

3 files changed

+149
-26
lines changed

3 files changed

+149
-26
lines changed

pytensor/configparser.py

Lines changed: 135 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)
1414
from functools import wraps
1515
from io import StringIO
16+
from pathlib import Path
1617

1718
from pytensor.utils import hash_from_code
1819

@@ -22,7 +23,7 @@
2223

2324
class PyTensorConfigWarning(Warning):
2425
@classmethod
25-
def warn(cls, message, stacklevel=0):
26+
def warn(cls, message: str, stacklevel: int = 0):
2627
warnings.warn(message, cls, stacklevel=stacklevel + 3)
2728

2829

@@ -68,7 +69,119 @@ def __exit__(self, *args):
6869
class PyTensorConfigParser:
6970
"""Object that holds configuration settings."""
7071

71-
def __init__(self, flags_dict: dict, pytensor_cfg, pytensor_raw_cfg):
72+
# add_basic_configvars
73+
floatX: str
74+
warn_float64: str
75+
pickle_test_value: bool
76+
cast_policy: str
77+
deterministic: str
78+
device: str
79+
force_device: bool
80+
conv__assert_shape: bool
81+
print_global_stats: bool
82+
assert_no_cpu_op: str
83+
unpickle_function: bool
84+
# add_compile_configvars
85+
mode: str
86+
cxx: str
87+
linker: str
88+
allow_gc: bool
89+
optimizer: str
90+
optimizer_verbose: bool
91+
on_opt_error: str
92+
nocleanup: bool
93+
on_unused_import: str
94+
gcc__cxxflags: str
95+
cmodule__warn_no_version: bool
96+
cmodule__remove_gxx_opt: bool
97+
cmodule__compilation_warning: bool
98+
cmodule__preload_cache: bool
99+
cmodule__age_thresh_use: int
100+
cmodule__debug: bool
101+
compile__wait: int
102+
compile__timeout: int
103+
ctc__root: str
104+
# add_tensor_configvars
105+
tensor__cmp_sloppy: int
106+
lib__amblibm: bool
107+
tensor__insert_inplace_optimizer_validate_nb: int
108+
# add_traceback_configvars
109+
traceback__limit: int
110+
traceback__compile_limit: int
111+
# add_experimental_configvars
112+
# add_error_and_warning_configvars
113+
warn__ignore_bug_before: int
114+
exception_verbosity: str
115+
# add_testvalue_and_checking_configvars
116+
print_test_value: bool
117+
compute_test_value: str
118+
compute_test_value_opt: str
119+
check_input: bool
120+
NanGuardMode__nan_is_error: bool
121+
NanGuardMode__inf_is_error: bool
122+
NanGuardMode__big_is_error: bool
123+
NanGuardMode__action: str
124+
DebugMode__patience: int
125+
DebugMode__check_c: bool
126+
DebugMode__check_py: bool
127+
DebugMode__check_finite: bool
128+
DebugMode__check_strides: int
129+
DebugMode__warn_input_not_reused: bool
130+
DebugMode__check_preallocated_output: str
131+
DebugMode__check_preallocated_output_ndim: int
132+
profiling__time_thunks: bool
133+
profiling__n_apply: int
134+
profiling__n_ops: int
135+
profiling__output_line_width: int
136+
profiling__min_memory_size: int
137+
profiling__min_peak_memory: bool
138+
profiling__destination: str
139+
profiling__debugprint: bool
140+
profiling__ignore_first_call: bool
141+
on_shape_error: str
142+
# add_multiprocessing_configvars
143+
openmp: bool
144+
openmp_elemwise_minsize: int
145+
# add_optimizer_configvars
146+
optimizer_excluding: str
147+
optimizer_including: str
148+
optimizer_requiring: str
149+
optdb__position_cutoff: float
150+
optdb__max_use_ratio: float
151+
cycle_detection: str
152+
check_stack_trace: str
153+
metaopt__verbose: int
154+
metaopt__optimizer_excluding: str
155+
metaopt__optimizer_including: str
156+
# add_vm_configvars
157+
profile: bool
158+
profile_optimizer: bool
159+
profile_memory: bool
160+
vm__lazy: bool | None
161+
# add_deprecated_configvars
162+
unittests__rseed: str
163+
warn__round: bool
164+
# add_scan_configvars
165+
scan__allow_gc: bool
166+
scan__allow_output_prealloc: bool
167+
# add_numba_configvars
168+
numba__vectorize_target: str
169+
numba__fastmath: bool
170+
numba__cache: bool
171+
# add_caching_dir_configvars
172+
compiledir_format: str
173+
base_compiledir: Path
174+
compiledir: Path
175+
# add_blas_configvars
176+
blas__ldflags: str
177+
blas__check_openmp: bool
178+
179+
def __init__(
180+
self,
181+
flags_dict: dict,
182+
pytensor_cfg: ConfigParser,
183+
pytensor_raw_cfg: RawConfigParser,
184+
):
72185
self._flags_dict = flags_dict
73186
self._pytensor_cfg = pytensor_cfg
74187
self._pytensor_raw_cfg = pytensor_raw_cfg
@@ -80,7 +193,7 @@ def __str__(self, print_doc=True):
80193
self.config_print(buf=sio, print_doc=print_doc)
81194
return sio.getvalue()
82195

83-
def config_print(self, buf, print_doc=True):
196+
def config_print(self, buf, print_doc: bool = True):
84197
for cv in self._config_var_dict.values():
85198
print(cv, file=buf)
86199
if print_doc:
@@ -108,7 +221,9 @@ def get_config_hash(self):
108221
)
109222
)
110223

111-
def add(self, name, doc, configparam, in_c_key=True):
224+
def add(
225+
self, name: str, doc: str, configparam: "ConfigParam", in_c_key: bool = True
226+
):
112227
"""Add a new variable to PyTensorConfigParser.
113228
114229
This method performs some of the work of initializing `ConfigParam` instances.
@@ -168,7 +283,7 @@ def add(self, name, doc, configparam, in_c_key=True):
168283
# the ConfigParam implements __get__/__set__, enabling us to create a property:
169284
setattr(self.__class__, name, configparam)
170285

171-
def fetch_val_for_key(self, key, delete_key=False):
286+
def fetch_val_for_key(self, key, delete_key: bool = False):
172287
"""Return the overriding config value for a key.
173288
A successful search returns a string value.
174289
An unsuccessful search raises a KeyError
@@ -260,9 +375,9 @@ def __init__(
260375
self._mutable = mutable
261376
self.is_default = True
262377
# set by PyTensorConfigParser.add:
263-
self.name = None
264-
self.doc = None
265-
self.in_c_key = None
378+
self.name: str
379+
self.doc: str
380+
self.in_c_key: bool
266381

267382
# Note that we do not call `self.filter` on the default value: this
268383
# will be done automatically in PyTensorConfigParser.add, potentially with a
@@ -336,7 +451,7 @@ def __set__(self, cls, val):
336451

337452
class EnumStr(ConfigParam):
338453
def __init__(
339-
self, default: str, options: Sequence[str], validate=None, mutable=True
454+
self, default: str, options: Sequence[str], validate=None, mutable: bool = True
340455
):
341456
"""Creates a str-based parameter that takes a predefined set of options.
342457
@@ -400,7 +515,7 @@ class BoolParam(TypedParam):
400515
True, 1, "true", "True", "1"
401516
"""
402517

403-
def __init__(self, default, validate=None, mutable=True):
518+
def __init__(self, default, validate=None, mutable: bool = True):
404519
super().__init__(default, apply=self._apply, validate=validate, mutable=mutable)
405520

406521
def _apply(self, value):
@@ -454,7 +569,9 @@ def _apply(self, val):
454569
return val
455570

456571

457-
def parse_config_string(config_string, issue_warnings=True):
572+
def parse_config_string(
573+
config_string: str, issue_warnings: bool = True
574+
) -> dict[str, str]:
458575
"""
459576
Parses a config string (comma-separated key=value components) into a dict.
460577
"""
@@ -480,7 +597,7 @@ def parse_config_string(config_string, issue_warnings=True):
480597
return config_dict
481598

482599

483-
def config_files_from_pytensorrc():
600+
def config_files_from_pytensorrc() -> list[Path]:
484601
"""
485602
PYTENSORRC can contain a colon-delimited list of config files, like
486603
@@ -489,17 +606,17 @@ def config_files_from_pytensorrc():
489606
In that case, definitions in files on the right (here, ``~/.pytensorrc``)
490607
have precedence over those in files on the left.
491608
"""
492-
rval = [
493-
os.path.expanduser(s)
609+
paths = [
610+
Path(s).expanduser()
494611
for s in os.getenv("PYTENSORRC", "~/.pytensorrc").split(os.pathsep)
495612
]
496613
if os.getenv("PYTENSORRC") is None and sys.platform == "win32":
497614
# to don't need to change the filename and make it open easily
498-
rval.append(os.path.expanduser("~/.pytensorrc.txt"))
499-
return rval
615+
paths.append(Path("~/.pytensorrc.txt").expanduser())
616+
return paths
500617

501618

502-
def _create_default_config():
619+
def _create_default_config() -> PyTensorConfigParser:
503620
# The PYTENSOR_FLAGS environment variable should be a list of comma-separated
504621
# [section__]option=value entries. If the section part is omitted, there should
505622
# be only one section that contains the given option.
@@ -509,7 +626,7 @@ def _create_default_config():
509626
config_files = config_files_from_pytensorrc()
510627
pytensor_cfg = ConfigParser(
511628
{
512-
"USER": os.getenv("USER", os.path.split(os.path.expanduser("~"))[-1]),
629+
"USER": os.getenv("USER", Path("~").expanduser().name),
513630
"LSCRATCH": os.getenv("LSCRATCH", ""),
514631
"TMPDIR": os.getenv("TMPDIR", ""),
515632
"TEMP": os.getenv("TEMP", ""),

pytensor/link/c/cmodule.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
from collections.abc import Callable
2424
from contextlib import AbstractContextManager, nullcontext
2525
from io import BytesIO, StringIO
26+
from pathlib import Path
2627
from typing import TYPE_CHECKING, Protocol, cast
2728

2829
import numpy as np
@@ -688,7 +689,7 @@ class ModuleCache:
688689
689690
"""
690691

691-
dirname: str = ""
692+
dirname: Path
692693
"""
693694
The working directory that is managed by this interface.
694695
@@ -725,8 +726,13 @@ class ModuleCache:
725726
726727
"""
727728

728-
def __init__(self, dirname, check_for_broken_eq=True, do_refresh=True):
729-
self.dirname = dirname
729+
def __init__(
730+
self,
731+
dirname: Path | str,
732+
check_for_broken_eq: bool = True,
733+
do_refresh: bool = True,
734+
):
735+
self.dirname = Path(dirname)
730736
self.module_from_name = dict(self.module_from_name)
731737
self.entry_from_key = dict(self.entry_from_key)
732738
self.module_hash_to_key_data = dict(self.module_hash_to_key_data)
@@ -1637,12 +1643,12 @@ def _rmtree(
16371643
_module_cache: ModuleCache | None = None
16381644

16391645

1640-
def get_module_cache(dirname: str, init_args=None) -> ModuleCache:
1646+
def get_module_cache(dirname: Path | str, init_args=None) -> ModuleCache:
16411647
"""Create a new module_cache.
16421648
16431649
Parameters
16441650
----------
1645-
dirname
1651+
dirname : Path | str
16461652
The name of the directory used by the cache.
16471653
init_args
16481654
Keyword arguments passed to the `ModuleCache` constructor.

pytensor/link/c/cutils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,8 +87,8 @@ def compile_cutils():
8787
# for the same reason. Note that these 5 lines may seem redundant (they are
8888
# repeated in compile_str()) but if another cutils_ext does exist then it
8989
# will be imported and compile_str won't get called at all.
90-
sys.path.insert(0, config.compiledir)
91-
location = os.path.join(config.compiledir, "cutils_ext")
90+
sys.path.insert(0, str(config.compiledir))
91+
location = os.path.join(str(config.compiledir), "cutils_ext")
9292
if not os.path.exists(location):
9393
try:
9494
os.mkdir(location)
@@ -115,5 +115,5 @@ def compile_cutils():
115115
compile_cutils()
116116
from cutils_ext.cutils_ext import * # noqa
117117
finally:
118-
if sys.path[0] == config.compiledir:
118+
if config.compiledir.samefile(sys.path[0]):
119119
del sys.path[0]

0 commit comments

Comments
 (0)