99import os
1010from pathlib import Path
1111import platform
12- from typing import Tuple , cast , Any , Optional
12+ from typing import Dict , Optional , Tuple , Union , cast
1313import warnings
1414
1515from io import StringIO
1919logger = logging .getLogger (__name__ )
2020openml_logger = logging .getLogger ("openml" )
2121console_handler = None
22- file_handler = None
22+ file_handler = None # type: Optional[logging.Handler]
2323
2424
25- def _create_log_handlers (create_file_handler = True ):
25+ def _create_log_handlers (create_file_handler : bool = True ) -> None :
2626 """Creates but does not attach the log handlers."""
2727 global console_handler , file_handler
2828 if console_handler is not None or file_handler is not None :
@@ -61,7 +61,7 @@ def _convert_log_levels(log_level: int) -> Tuple[int, int]:
6161 return openml_level , python_level
6262
6363
64- def _set_level_register_and_store (handler : logging .Handler , log_level : int ):
64+ def _set_level_register_and_store (handler : logging .Handler , log_level : int ) -> None :
6565 """Set handler log level, register it if needed, save setting to config file if specified."""
6666 oml_level , py_level = _convert_log_levels (log_level )
6767 handler .setLevel (py_level )
@@ -73,13 +73,13 @@ def _set_level_register_and_store(handler: logging.Handler, log_level: int):
7373 openml_logger .addHandler (handler )
7474
7575
76- def set_console_log_level (console_output_level : int ):
76+ def set_console_log_level (console_output_level : int ) -> None :
7777 """Set console output to the desired level and register it with openml logger if needed."""
7878 global console_handler
7979 _set_level_register_and_store (cast (logging .Handler , console_handler ), console_output_level )
8080
8181
82- def set_file_log_level (file_output_level : int ):
82+ def set_file_log_level (file_output_level : int ) -> None :
8383 """Set file output to the desired level and register it with openml logger if needed."""
8484 global file_handler
8585 _set_level_register_and_store (cast (logging .Handler , file_handler ), file_output_level )
@@ -139,7 +139,8 @@ def set_retry_policy(value: str, n_retries: Optional[int] = None) -> None:
139139
140140 if value not in default_retries_by_policy :
141141 raise ValueError (
142- f"Detected retry_policy '{ value } ' but must be one of { default_retries_by_policy } "
142+ f"Detected retry_policy '{ value } ' but must be one of "
143+ f"{ list (default_retries_by_policy .keys ())} "
143144 )
144145 if n_retries is not None and not isinstance (n_retries , int ):
145146 raise TypeError (f"`n_retries` must be of type `int` or `None` but is `{ type (n_retries )} `." )
@@ -160,7 +161,7 @@ class ConfigurationForExamples:
160161 _test_apikey = "c0c42819af31e706efe1f4b88c23c6c1"
161162
162163 @classmethod
163- def start_using_configuration_for_example (cls ):
164+ def start_using_configuration_for_example (cls ) -> None :
164165 """Sets the configuration to connect to the test server with valid apikey.
165166
166167 To configuration as was before this call is stored, and can be recovered
@@ -187,7 +188,7 @@ def start_using_configuration_for_example(cls):
187188 )
188189
189190 @classmethod
190- def stop_using_configuration_for_example (cls ):
191+ def stop_using_configuration_for_example (cls ) -> None :
191192 """Return to configuration as it was before `start_use_example_configuration`."""
192193 if not cls ._start_last_called :
193194 # We don't want to allow this because it will (likely) result in the `server` and
@@ -200,8 +201,8 @@ def stop_using_configuration_for_example(cls):
200201 global server
201202 global apikey
202203
203- server = cls ._last_used_server
204- apikey = cls ._last_used_key
204+ server = cast ( str , cls ._last_used_server )
205+ apikey = cast ( str , cls ._last_used_key )
205206 cls ._start_last_called = False
206207
207208
@@ -215,7 +216,7 @@ def determine_config_file_path() -> Path:
215216 return config_dir / "config"
216217
217218
218- def _setup (config = None ):
219+ def _setup (config : Optional [ Dict [ str , Union [ str , int , bool ]]] = None ) -> None :
219220 """Setup openml package. Called on first import.
220221
221222 Reads the config file and sets up apikey, server, cache appropriately.
@@ -243,28 +244,22 @@ def _setup(config=None):
243244 cache_exists = True
244245
245246 if config is None :
246- config = _parse_config (config_file )
247+ config = cast (Dict [str , Union [str , int , bool ]], _parse_config (config_file ))
248+ config = cast (Dict [str , Union [str , int , bool ]], config )
247249
248- def _get (config , key ):
249- return config .get ("FAKE_SECTION" , key )
250+ avoid_duplicate_runs = bool (config .get ("avoid_duplicate_runs" ))
250251
251- avoid_duplicate_runs = config .getboolean ("FAKE_SECTION" , "avoid_duplicate_runs" )
252- else :
253-
254- def _get (config , key ):
255- return config .get (key )
256-
257- avoid_duplicate_runs = config .get ("avoid_duplicate_runs" )
252+ apikey = cast (str , config ["apikey" ])
253+ server = cast (str , config ["server" ])
254+ short_cache_dir = cast (str , config ["cachedir" ])
258255
259- apikey = _get (config , "apikey" )
260- server = _get (config , "server" )
261- short_cache_dir = _get (config , "cachedir" )
262-
263- n_retries = _get (config , "connection_n_retries" )
264- if n_retries is not None :
265- n_retries = int (n_retries )
256+ tmp_n_retries = config ["connection_n_retries" ]
257+ if tmp_n_retries is not None :
258+ n_retries = int (tmp_n_retries )
259+ else :
260+ n_retries = None
266261
267- set_retry_policy (_get ( config , "retry_policy" ), n_retries )
262+ set_retry_policy (cast ( str , config [ "retry_policy" ] ), n_retries )
268263
269264 _root_cache_directory = os .path .expanduser (short_cache_dir )
270265 # create the cache subdirectory
@@ -287,10 +282,10 @@ def _get(config, key):
287282 )
288283
289284
290- def set_field_in_config_file (field : str , value : Any ) :
285+ def set_field_in_config_file (field : str , value : str ) -> None :
291286 """Overwrites the `field` in the configuration file with the new `value`."""
292287 if field not in _defaults :
293- return ValueError (f"Field '{ field } ' is not valid and must be one of '{ _defaults .keys ()} '." )
288+ raise ValueError (f"Field '{ field } ' is not valid and must be one of '{ _defaults .keys ()} '." )
294289
295290 globals ()[field ] = value
296291 config_file = determine_config_file_path ()
@@ -308,7 +303,7 @@ def set_field_in_config_file(field: str, value: Any):
308303 fh .write (f"{ f } = { value } \n " )
309304
310305
311- def _parse_config (config_file : str ) :
306+ def _parse_config (config_file : Union [ str , Path ]) -> Dict [ str , str ] :
312307 """Parse the config file, set up defaults."""
313308 config = configparser .RawConfigParser (defaults = _defaults )
314309
@@ -326,11 +321,12 @@ def _parse_config(config_file: str):
326321 logger .info ("Error opening file %s: %s" , config_file , e .args [0 ])
327322 config_file_ .seek (0 )
328323 config .read_file (config_file_ )
329- return config
324+ config_as_dict = {key : value for key , value in config .items ("FAKE_SECTION" )}
325+ return config_as_dict
330326
331327
332- def get_config_as_dict ():
333- config = dict ()
328+ def get_config_as_dict () -> Dict [ str , Union [ str , int , bool ]] :
329+ config = dict () # type: Dict[str, Union[str, int, bool]]
334330 config ["apikey" ] = apikey
335331 config ["server" ] = server
336332 config ["cachedir" ] = _root_cache_directory
@@ -340,7 +336,7 @@ def get_config_as_dict():
340336 return config
341337
342338
343- def get_cache_directory ():
339+ def get_cache_directory () -> str :
344340 """Get the current cache directory.
345341
346342 This gets the cache directory for the current server relative
@@ -366,7 +362,7 @@ def get_cache_directory():
366362 return _cachedir
367363
368364
369- def set_root_cache_directory (root_cache_directory ) :
365+ def set_root_cache_directory (root_cache_directory : str ) -> None :
370366 """Set module-wide base cache directory.
371367
372368 Sets the root cache directory, wherin the cache directories are
0 commit comments