13
13
)
14
14
from functools import wraps
15
15
from io import StringIO
16
+ from pathlib import Path
16
17
17
18
from pytensor .utils import hash_from_code
18
19
22
23
23
24
class PyTensorConfigWarning (Warning ):
24
25
@classmethod
25
- def warn (cls , message , stacklevel = 0 ):
26
+ def warn (cls , message : str , stacklevel : int = 0 ):
26
27
warnings .warn (message , cls , stacklevel = stacklevel + 3 )
27
28
28
29
@@ -68,7 +69,119 @@ def __exit__(self, *args):
68
69
class PyTensorConfigParser :
69
70
"""Object that holds configuration settings."""
70
71
71
- def __init__ (self , flags_dict : dict , pytensor_cfg , pytensor_raw_cfg ):
72
+ # add_basic_configvars
73
+ floatX : str
74
+ warn_float64 : str
75
+ pickle_test_value : bool
76
+ cast_policy : str
77
+ deterministic : str
78
+ device : str
79
+ force_device : bool
80
+ conv__assert_shape : bool
81
+ print_global_stats : bool
82
+ assert_no_cpu_op : str
83
+ unpickle_function : bool
84
+ # add_compile_configvars
85
+ mode : str
86
+ cxx : str
87
+ linker : str
88
+ allow_gc : bool
89
+ optimizer : str
90
+ optimizer_verbose : bool
91
+ on_opt_error : str
92
+ nocleanup : bool
93
+ on_unused_import : str
94
+ gcc__cxxflags : str
95
+ cmodule__warn_no_version : bool
96
+ cmodule__remove_gxx_opt : bool
97
+ cmodule__compilation_warning : bool
98
+ cmodule__preload_cache : bool
99
+ cmodule__age_thresh_use : int
100
+ cmodule__debug : bool
101
+ compile__wait : int
102
+ compile__timeout : int
103
+ ctc__root : str
104
+ # add_tensor_configvars
105
+ tensor__cmp_sloppy : int
106
+ lib__amblibm : bool
107
+ tensor__insert_inplace_optimizer_validate_nb : int
108
+ # add_traceback_configvars
109
+ traceback__limit : int
110
+ traceback__compile_limit : int
111
+ # add_experimental_configvars
112
+ # add_error_and_warning_configvars
113
+ warn__ignore_bug_before : int
114
+ exception_verbosity : str
115
+ # add_testvalue_and_checking_configvars
116
+ print_test_value : bool
117
+ compute_test_value : str
118
+ compute_test_value_opt : str
119
+ check_input : bool
120
+ NanGuardMode__nan_is_error : bool
121
+ NanGuardMode__inf_is_error : bool
122
+ NanGuardMode__big_is_error : bool
123
+ NanGuardMode__action : str
124
+ DebugMode__patience : int
125
+ DebugMode__check_c : bool
126
+ DebugMode__check_py : bool
127
+ DebugMode__check_finite : bool
128
+ DebugMode__check_strides : int
129
+ DebugMode__warn_input_not_reused : bool
130
+ DebugMode__check_preallocated_output : str
131
+ DebugMode__check_preallocated_output_ndim : int
132
+ profiling__time_thunks : bool
133
+ profiling__n_apply : int
134
+ profiling__n_ops : int
135
+ profiling__output_line_width : int
136
+ profiling__min_memory_size : int
137
+ profiling__min_peak_memory : bool
138
+ profiling__destination : str
139
+ profiling__debugprint : bool
140
+ profiling__ignore_first_call : bool
141
+ on_shape_error : str
142
+ # add_multiprocessing_configvars
143
+ openmp : bool
144
+ openmp_elemwise_minsize : int
145
+ # add_optimizer_configvars
146
+ optimizer_excluding : str
147
+ optimizer_including : str
148
+ optimizer_requiring : str
149
+ optdb__position_cutoff : float
150
+ optdb__max_use_ratio : float
151
+ cycle_detection : str
152
+ check_stack_trace : str
153
+ metaopt__verbose : int
154
+ metaopt__optimizer_excluding : str
155
+ metaopt__optimizer_including : str
156
+ # add_vm_configvars
157
+ profile : bool
158
+ profile_optimizer : bool
159
+ profile_memory : bool
160
+ vm__lazy : bool | None
161
+ # add_deprecated_configvars
162
+ unittests__rseed : str
163
+ warn__round : bool
164
+ # add_scan_configvars
165
+ scan__allow_gc : bool
166
+ scan__allow_output_prealloc : bool
167
+ # add_numba_configvars
168
+ numba__vectorize_target : str
169
+ numba__fastmath : bool
170
+ numba__cache : bool
171
+ # add_caching_dir_configvars
172
+ compiledir_format : str
173
+ base_compiledir : Path
174
+ compiledir : Path
175
+ # add_blas_configvars
176
+ blas__ldflags : str
177
+ blas__check_openmp : bool
178
+
179
+ def __init__ (
180
+ self ,
181
+ flags_dict : dict ,
182
+ pytensor_cfg : ConfigParser ,
183
+ pytensor_raw_cfg : RawConfigParser ,
184
+ ):
72
185
self ._flags_dict = flags_dict
73
186
self ._pytensor_cfg = pytensor_cfg
74
187
self ._pytensor_raw_cfg = pytensor_raw_cfg
@@ -80,7 +193,7 @@ def __str__(self, print_doc=True):
80
193
self .config_print (buf = sio , print_doc = print_doc )
81
194
return sio .getvalue ()
82
195
83
- def config_print (self , buf , print_doc = True ):
196
+ def config_print (self , buf , print_doc : bool = True ):
84
197
for cv in self ._config_var_dict .values ():
85
198
print (cv , file = buf )
86
199
if print_doc :
@@ -108,7 +221,9 @@ def get_config_hash(self):
108
221
)
109
222
)
110
223
111
- def add (self , name , doc , configparam , in_c_key = True ):
224
+ def add (
225
+ self , name : str , doc : str , configparam : "ConfigParam" , in_c_key : bool = True
226
+ ):
112
227
"""Add a new variable to PyTensorConfigParser.
113
228
114
229
This method performs some of the work of initializing `ConfigParam` instances.
@@ -168,7 +283,7 @@ def add(self, name, doc, configparam, in_c_key=True):
168
283
# the ConfigParam implements __get__/__set__, enabling us to create a property:
169
284
setattr (self .__class__ , name , configparam )
170
285
171
- def fetch_val_for_key (self , key , delete_key = False ):
286
+ def fetch_val_for_key (self , key , delete_key : bool = False ):
172
287
"""Return the overriding config value for a key.
173
288
A successful search returns a string value.
174
289
An unsuccessful search raises a KeyError
@@ -260,9 +375,9 @@ def __init__(
260
375
self ._mutable = mutable
261
376
self .is_default = True
262
377
# set by PyTensorConfigParser.add:
263
- self .name = None
264
- self .doc = None
265
- self .in_c_key = None
378
+ self .name : str
379
+ self .doc : str
380
+ self .in_c_key : bool
266
381
267
382
# Note that we do not call `self.filter` on the default value: this
268
383
# will be done automatically in PyTensorConfigParser.add, potentially with a
@@ -336,7 +451,7 @@ def __set__(self, cls, val):
336
451
337
452
class EnumStr (ConfigParam ):
338
453
def __init__ (
339
- self , default : str , options : Sequence [str ], validate = None , mutable = True
454
+ self , default : str , options : Sequence [str ], validate = None , mutable : bool = True
340
455
):
341
456
"""Creates a str-based parameter that takes a predefined set of options.
342
457
@@ -400,7 +515,7 @@ class BoolParam(TypedParam):
400
515
True, 1, "true", "True", "1"
401
516
"""
402
517
403
- def __init__ (self , default , validate = None , mutable = True ):
518
+ def __init__ (self , default , validate = None , mutable : bool = True ):
404
519
super ().__init__ (default , apply = self ._apply , validate = validate , mutable = mutable )
405
520
406
521
def _apply (self , value ):
@@ -454,7 +569,9 @@ def _apply(self, val):
454
569
return val
455
570
456
571
457
- def parse_config_string (config_string , issue_warnings = True ):
572
+ def parse_config_string (
573
+ config_string : str , issue_warnings : bool = True
574
+ ) -> dict [str , str ]:
458
575
"""
459
576
Parses a config string (comma-separated key=value components) into a dict.
460
577
"""
@@ -480,7 +597,7 @@ def parse_config_string(config_string, issue_warnings=True):
480
597
return config_dict
481
598
482
599
483
- def config_files_from_pytensorrc ():
600
+ def config_files_from_pytensorrc () -> list [ Path ] :
484
601
"""
485
602
PYTENSORRC can contain a colon-delimited list of config files, like
486
603
@@ -489,17 +606,17 @@ def config_files_from_pytensorrc():
489
606
In that case, definitions in files on the right (here, ``~/.pytensorrc``)
490
607
have precedence over those in files on the left.
491
608
"""
492
- rval = [
493
- os . path . expanduser (s )
609
+ paths = [
610
+ Path ( s ). expanduser ()
494
611
for s in os .getenv ("PYTENSORRC" , "~/.pytensorrc" ).split (os .pathsep )
495
612
]
496
613
if os .getenv ("PYTENSORRC" ) is None and sys .platform == "win32" :
497
614
# to don't need to change the filename and make it open easily
498
- rval .append (os . path . expanduser ("~/.pytensorrc.txt" ))
499
- return rval
615
+ paths .append (Path ("~/.pytensorrc.txt" ). expanduser ( ))
616
+ return paths
500
617
501
618
502
- def _create_default_config ():
619
+ def _create_default_config () -> PyTensorConfigParser :
503
620
# The PYTENSOR_FLAGS environment variable should be a list of comma-separated
504
621
# [section__]option=value entries. If the section part is omitted, there should
505
622
# be only one section that contains the given option.
@@ -509,7 +626,7 @@ def _create_default_config():
509
626
config_files = config_files_from_pytensorrc ()
510
627
pytensor_cfg = ConfigParser (
511
628
{
512
- "USER" : os .getenv ("USER" , os . path . split ( os . path . expanduser ( "~" ))[ - 1 ] ),
629
+ "USER" : os .getenv ("USER" , Path ( "~" ). expanduser (). name ),
513
630
"LSCRATCH" : os .getenv ("LSCRATCH" , "" ),
514
631
"TMPDIR" : os .getenv ("TMPDIR" , "" ),
515
632
"TEMP" : os .getenv ("TEMP" , "" ),
0 commit comments