13
13
)
14
14
from functools import wraps
15
15
from io import StringIO
16
+ from pathlib import Path
16
17
17
18
from pytensor .utils import hash_from_code
18
19
22
23
23
24
class PyTensorConfigWarning (Warning ):
24
25
@classmethod
25
- def warn (cls , message , stacklevel = 0 ):
26
+ def warn (cls , message : str , stacklevel : int = 0 ):
26
27
warnings .warn (message , cls , stacklevel = stacklevel + 3 )
27
28
28
29
@@ -68,7 +69,123 @@ def __exit__(self, *args):
68
69
class PyTensorConfigParser :
69
70
"""Object that holds configuration settings."""
70
71
71
- def __init__ (self , flags_dict : dict , pytensor_cfg , pytensor_raw_cfg ):
72
+ # add_basic_configvars
73
+ floatX : str
74
+ warn_float64 : str
75
+ pickle_test_value : bool
76
+ cast_policy : str
77
+ deterministic : str
78
+ device : str
79
+ force_device : bool
80
+ conv__assert_shape : bool
81
+ print_global_stats : bool
82
+ assert_no_cpu_op : str
83
+ unpickle_function : bool
84
+ # add_compile_configvars
85
+ mode : str
86
+ cxx : str
87
+ linker : str
88
+ allow_gc : bool
89
+ optimizer : str
90
+ optimizer_verbose : bool
91
+ on_opt_error : str
92
+ nocleanup : bool
93
+ on_unused_import : str
94
+ gcc__cxxflags : str
95
+ cmodule__warn_no_version : bool
96
+ cmodule__remove_gxx_opt : bool
97
+ cmodule__compilation_warning : bool
98
+ cmodule__preload_cache : bool
99
+ cmodule__age_thresh_use : int
100
+ cmodule__debug : bool
101
+ compile__wait : int
102
+ compile__timeout : int
103
+ ctc__root : str
104
+ # add_tensor_configvars
105
+ tensor__cmp_sloppy : int
106
+ lib__amblibm : bool
107
+ tensor__insert_inplace_optimizer_validate_nb : int
108
+ # add_traceback_configvars
109
+ traceback__limit : int
110
+ traceback__compile_limit : int
111
+ # add_experimental_configvars
112
+ # add_error_and_warning_configvars
113
+ warn__ignore_bug_before : int
114
+ exception_verbosity : str
115
+ # add_testvalue_and_checking_configvars
116
+ print_test_value : bool
117
+ compute_test_value : str
118
+ compute_test_value_opt : str
119
+ check_input : bool
120
+ NanGuardMode__nan_is_error : bool
121
+ NanGuardMode__inf_is_error : bool
122
+ NanGuardMode__big_is_error : bool
123
+ NanGuardMode__action : str
124
+ DebugMode__patience : int
125
+ DebugMode__check_c : bool
126
+ DebugMode__check_py : bool
127
+ DebugMode__check_finite : bool
128
+ DebugMode__check_strides : int
129
+ DebugMode__warn_input_not_reused : bool
130
+ DebugMode__check_preallocated_output : str
131
+ DebugMode__check_preallocated_output_ndim : int
132
+ profiling__time_thunks : bool
133
+ profiling__n_apply : int
134
+ profiling__n_ops : int
135
+ profiling__output_line_width : int
136
+ profiling__min_memory_size : int
137
+ profiling__min_peak_memory : bool
138
+ profiling__destination : str
139
+ profiling__debugprint : bool
140
+ profiling__ignore_first_call : bool
141
+ on_shape_error : str
142
+ # add_multiprocessing_configvars
143
+ openmp : bool
144
+ openmp_elemwise_minsize : int
145
+ # add_optimizer_configvars
146
+ optimizer_excluding : str
147
+ optimizer_including : str
148
+ optimizer_requiring : str
149
+ optdb__position_cutoff : float
150
+ optdb__max_use_ratio : float
151
+ cycle_detection : str
152
+ check_stack_trace : str
153
+ metaopt__verbose : int
154
+ metaopt__optimizer_excluding : str
155
+ metaopt__optimizer_including : str
156
+ # add_vm_configvars
157
+ profile : bool
158
+ profile_optimizer : bool
159
+ profile_memory : bool
160
+ vm__lazy : bool | None
161
+ # add_deprecated_configvars
162
+ unittests__rseed : str
163
+ warn__round : bool
164
+ # add_scan_configvars
165
+ scan__allow_gc : bool
166
+ scan__allow_output_prealloc : bool
167
+ # add_numba_configvars
168
+ numba__vectorize_target : str
169
+ numba__fastmath : bool
170
+ numba__cache : bool
171
+ # add_caching_dir_configvars
172
+ compiledir_format : str
173
+ base_compiledir : Path
174
+ compiledir : Path
175
+ # add_blas_configvars
176
+ blas__ldflags : str
177
+ blas__check_openmp : bool
178
+ # add CUDA (?)
179
+ cuda__root : Path | None
180
+ dnn__base_path : Path | None
181
+ dnn__library_path : Path | None
182
+
183
+ def __init__ (
184
+ self ,
185
+ flags_dict : dict ,
186
+ pytensor_cfg : ConfigParser ,
187
+ pytensor_raw_cfg : RawConfigParser ,
188
+ ):
72
189
self ._flags_dict = flags_dict
73
190
self ._pytensor_cfg = pytensor_cfg
74
191
self ._pytensor_raw_cfg = pytensor_raw_cfg
@@ -80,7 +197,7 @@ def __str__(self, print_doc=True):
80
197
self .config_print (buf = sio , print_doc = print_doc )
81
198
return sio .getvalue ()
82
199
83
- def config_print (self , buf , print_doc = True ):
200
+ def config_print (self , buf , print_doc : bool = True ):
84
201
for cv in self ._config_var_dict .values ():
85
202
print (cv , file = buf )
86
203
if print_doc :
@@ -108,7 +225,9 @@ def get_config_hash(self):
108
225
)
109
226
)
110
227
111
- def add (self , name , doc , configparam , in_c_key = True ):
228
+ def add (
229
+ self , name : str , doc : str , configparam : "ConfigParam" , in_c_key : bool = True
230
+ ):
112
231
"""Add a new variable to PyTensorConfigParser.
113
232
114
233
This method performs some of the work of initializing `ConfigParam` instances.
@@ -168,7 +287,7 @@ def add(self, name, doc, configparam, in_c_key=True):
168
287
# the ConfigParam implements __get__/__set__, enabling us to create a property:
169
288
setattr (self .__class__ , name , configparam )
170
289
171
- def fetch_val_for_key (self , key , delete_key = False ):
290
+ def fetch_val_for_key (self , key , delete_key : bool = False ):
172
291
"""Return the overriding config value for a key.
173
292
A successful search returns a string value.
174
293
An unsuccessful search raises a KeyError
@@ -260,9 +379,9 @@ def __init__(
260
379
self ._mutable = mutable
261
380
self .is_default = True
262
381
# set by PyTensorConfigParser.add:
263
- self .name = None
264
- self .doc = None
265
- self .in_c_key = None
382
+ self .name : str = "unnamed"
383
+ self .doc : str = "undocumented"
384
+ self .in_c_key : bool
266
385
267
386
# Note that we do not call `self.filter` on the default value: this
268
387
# will be done automatically in PyTensorConfigParser.add, potentially with a
@@ -288,7 +407,7 @@ def apply(self, value):
288
407
return self ._apply (value )
289
408
return value
290
409
291
- def validate (self , value ) -> bool | None :
410
+ def validate (self , value ) -> bool :
292
411
"""Validates that a parameter values falls into a supported set or range.
293
412
294
413
Raises
@@ -336,7 +455,7 @@ def __set__(self, cls, val):
336
455
337
456
class EnumStr (ConfigParam ):
338
457
def __init__ (
339
- self , default : str , options : Sequence [str ], validate = None , mutable = True
458
+ self , default : str , options : Sequence [str ], validate = None , mutable : bool = True
340
459
):
341
460
"""Creates a str-based parameter that takes a predefined set of options.
342
461
@@ -400,7 +519,7 @@ class BoolParam(TypedParam):
400
519
True, 1, "true", "True", "1"
401
520
"""
402
521
403
- def __init__ (self , default , validate = None , mutable = True ):
522
+ def __init__ (self , default , validate = None , mutable : bool = True ):
404
523
super ().__init__ (default , apply = self ._apply , validate = validate , mutable = mutable )
405
524
406
525
def _apply (self , value ):
@@ -454,7 +573,9 @@ def _apply(self, val):
454
573
return val
455
574
456
575
457
- def parse_config_string (config_string , issue_warnings = True ):
576
+ def parse_config_string (
577
+ config_string : str , issue_warnings : bool = True
578
+ ) -> dict [str , str ]:
458
579
"""
459
580
Parses a config string (comma-separated key=value components) into a dict.
460
581
"""
@@ -480,7 +601,7 @@ def parse_config_string(config_string, issue_warnings=True):
480
601
return config_dict
481
602
482
603
483
- def config_files_from_pytensorrc ():
604
+ def config_files_from_pytensorrc () -> list [ Path ] :
484
605
"""
485
606
PYTENSORRC can contain a colon-delimited list of config files, like
486
607
@@ -489,17 +610,17 @@ def config_files_from_pytensorrc():
489
610
In that case, definitions in files on the right (here, ``~/.pytensorrc``)
490
611
have precedence over those in files on the left.
491
612
"""
492
- rval = [
493
- os . path . expanduser (s )
613
+ paths = [
614
+ Path ( s ). expanduser ()
494
615
for s in os .getenv ("PYTENSORRC" , "~/.pytensorrc" ).split (os .pathsep )
495
616
]
496
617
if os .getenv ("PYTENSORRC" ) is None and sys .platform == "win32" :
497
618
# to don't need to change the filename and make it open easily
498
- rval .append (os . path . expanduser ("~/.pytensorrc.txt" ))
499
- return rval
619
+ paths .append (Path ("~/.pytensorrc.txt" ). expanduser ( ))
620
+ return paths
500
621
501
622
502
- def _create_default_config ():
623
+ def _create_default_config () -> PyTensorConfigParser :
503
624
# The PYTENSOR_FLAGS environment variable should be a list of comma-separated
504
625
# [section__]option=value entries. If the section part is omitted, there should
505
626
# be only one section that contains the given option.
@@ -509,7 +630,7 @@ def _create_default_config():
509
630
config_files = config_files_from_pytensorrc ()
510
631
pytensor_cfg = ConfigParser (
511
632
{
512
- "USER" : os .getenv ("USER" , os . path . split ( os . path . expanduser ( "~" ))[ - 1 ] ),
633
+ "USER" : os .getenv ("USER" , Path ( "~" ). expanduser (). name ),
513
634
"LSCRATCH" : os .getenv ("LSCRATCH" , "" ),
514
635
"TMPDIR" : os .getenv ("TMPDIR" , "" ),
515
636
"TEMP" : os .getenv ("TEMP" , "" ),
0 commit comments