1212from io import StringIO
1313from pathlib import Path
1414from typing import Dict , Union , cast
15+ from typing_extensions import Literal
1516from urllib .parse import urlparse
1617
1718logger = logging .getLogger (__name__ )
1819openml_logger = logging .getLogger ("openml" )
19- console_handler = None
20- file_handler = None # type: Optional[logging.Handler]
20+ console_handler : logging . StreamHandler | None = None
21+ file_handler : logging . handlers . RotatingFileHandler | None = None
2122
2223
23- def _create_log_handlers (create_file_handler : bool = True ) -> None :
24+ def _create_log_handlers (create_file_handler : bool = True ) -> None : # noqa: FBT
2425 """Creates but does not attach the log handlers."""
25- global console_handler , file_handler
26+ global console_handler , file_handler # noqa: PLW0603
2627 if console_handler is not None or file_handler is not None :
2728 logger .debug ("Requested to create log handlers, but they are already created." )
2829 return
@@ -35,7 +36,7 @@ def _create_log_handlers(create_file_handler: bool = True) -> None:
3536
3637 if create_file_handler :
3738 one_mb = 2 ** 20
38- log_path = os . path . join ( _root_cache_directory , "openml_python.log" )
39+ log_path = _root_cache_directory / "openml_python.log"
3940 file_handler = logging .handlers .RotatingFileHandler (
4041 log_path ,
4142 maxBytes = one_mb ,
@@ -64,7 +65,7 @@ def _convert_log_levels(log_level: int) -> tuple[int, int]:
6465
6566def _set_level_register_and_store (handler : logging .Handler , log_level : int ) -> None :
6667 """Set handler log level, register it if needed, save setting to config file if specified."""
67- oml_level , py_level = _convert_log_levels (log_level )
68+ _oml_level , py_level = _convert_log_levels (log_level )
6869 handler .setLevel (py_level )
6970
7071 if openml_logger .level > py_level or openml_logger .level == logging .NOTSET :
@@ -76,31 +77,27 @@ def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> N
7677
7778def set_console_log_level (console_output_level : int ) -> None :
7879 """Set console output to the desired level and register it with openml logger if needed."""
79- global console_handler
80- _set_level_register_and_store (cast (logging .Handler , console_handler ), console_output_level )
80+ global console_handler # noqa: PLW0602
81+ assert console_handler is not None
82+ _set_level_register_and_store (console_handler , console_output_level )
8183
8284
8385def set_file_log_level (file_output_level : int ) -> None :
8486 """Set file output to the desired level and register it with openml logger if needed."""
85- global file_handler
86- _set_level_register_and_store (cast (logging .Handler , file_handler ), file_output_level )
87+ global file_handler # noqa: PLW0602
88+ assert file_handler is not None
89+ _set_level_register_and_store (file_handler , file_output_level )
8790
8891
8992# Default values (see also https://github.com/openml/OpenML/wiki/Client-API-Standards)
93+ _user_path = Path ("~" ).expanduser ().absolute ()
9094_defaults = {
9195 "apikey" : "" ,
9296 "server" : "https://www.openml.org/api/v1/xml" ,
9397 "cachedir" : (
94- os .environ .get (
95- "XDG_CACHE_HOME" ,
96- os .path .join (
97- "~" ,
98- ".cache" ,
99- "openml" ,
100- ),
101- )
98+ os .environ .get ("XDG_CACHE_HOME" , _user_path / ".cache" / "openml" )
10299 if platform .system () == "Linux"
103- else os . path . join ( "~" , ".openml" )
100+ else _user_path / ".openml"
104101 ),
105102 "avoid_duplicate_runs" : "True" ,
106103 "retry_policy" : "human" ,
@@ -124,18 +121,18 @@ def get_server_base_url() -> str:
124121 return server .split ("/api" )[0 ]
125122
126123
127- apikey = _defaults ["apikey" ]
124+ apikey : str = _defaults ["apikey" ]
128125# The current cache directory (without the server name)
129- _root_cache_directory = str (_defaults ["cachedir" ]) # so mypy knows it is a string
130- avoid_duplicate_runs = _defaults ["avoid_duplicate_runs" ] == "True"
126+ _root_cache_directory = Path (_defaults ["cachedir" ])
127+ avoid_duplicate_runs : bool = _defaults ["avoid_duplicate_runs" ] == "True"
131128
132129retry_policy = _defaults ["retry_policy" ]
133130connection_n_retries = int (_defaults ["connection_n_retries" ])
134131
135132
136- def set_retry_policy (value : str , n_retries : int | None = None ) -> None :
137- global retry_policy
138- global connection_n_retries
133+ def set_retry_policy (value : Literal [ "human" , "robot" ] , n_retries : int | None = None ) -> None :
134+ global retry_policy # noqa: PLW0603
135+ global connection_n_retries # noqa: PLW0603
139136 default_retries_by_policy = {"human" : 5 , "robot" : 50 }
140137
141138 if value not in default_retries_by_policy :
@@ -145,6 +142,7 @@ def set_retry_policy(value: str, n_retries: int | None = None) -> None:
145142 )
146143 if n_retries is not None and not isinstance (n_retries , int ):
147144 raise TypeError (f"`n_retries` must be of type `int` or `None` but is `{ type (n_retries )} `." )
145+
148146 if isinstance (n_retries , int ) and n_retries < 1 :
149147 raise ValueError (f"`n_retries` is '{ n_retries } ' but must be positive." )
150148
@@ -168,8 +166,8 @@ def start_using_configuration_for_example(cls) -> None:
168166 The configuration as it was before this call is stored, and can be recovered
169167 by using the `stop_using_configuration_for_example` method.
170168 """
171- global server
172- global apikey
169+ global server # noqa: PLW0603
170+ global apikey # noqa: PLW0603
173171
174172 if cls ._start_last_called and server == cls ._test_server and apikey == cls ._test_apikey :
175173 # Method is called more than once in a row without modifying the server or apikey.
@@ -186,6 +184,7 @@ def start_using_configuration_for_example(cls) -> None:
186184 warnings .warn (
187185 f"Switching to the test server { server } to not upload results to the live server. "
188186 "Using the test server may result in reduced performance of the API!" ,
187+ stacklevel = 2 ,
189188 )
190189
191190 @classmethod
@@ -199,8 +198,8 @@ def stop_using_configuration_for_example(cls) -> None:
199198 "`start_use_example_configuration` must be called first." ,
200199 )
201200
202- global server
203- global apikey
201+ global server # noqa: PLW0603
202+ global apikey # noqa: PLW0603
204203
205204 server = cast (str , cls ._last_used_server )
206205 apikey = cast (str , cls ._last_used_key )
@@ -213,7 +212,7 @@ def determine_config_file_path() -> Path:
213212 else :
214213 config_dir = Path ("~" ) / ".openml"
215214 # NOTE(review): this used to call os.path.expanduser so the unit-test mock triggered; confirm the mock still applies to Path.expanduser
216- config_dir = Path (os . path . expanduser (config_dir ) )
215+ config_dir = Path (config_dir ). expanduser (). resolve ( )
217216 return config_dir / "config"
218217
219218
@@ -226,18 +225,18 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:
226225 openml.config.server = SOMESERVER
227226 We could also make it a property but that's less clear.
228227 """
229- global apikey
230- global server
231- global _root_cache_directory
232- global avoid_duplicate_runs
228+ global apikey # noqa: PLW0603
229+ global server # noqa: PLW0603
230+ global _root_cache_directory # noqa: PLW0603
231+ global avoid_duplicate_runs # noqa: PLW0603
233232
234233 config_file = determine_config_file_path ()
235234 config_dir = config_file .parent
236235
237236 # read config file, create directory for config file
238- if not os . path . exists (config_dir ):
237+ if not config_dir . exists ():
239238 try :
240- os . makedirs ( config_dir , exist_ok = True )
239+ config_dir . mkdir ( exist_ok = True , parents = True )
241240 cache_exists = True
242241 except PermissionError :
243242 cache_exists = False
@@ -250,20 +249,20 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:
250249
251250 avoid_duplicate_runs = bool (config .get ("avoid_duplicate_runs" ))
252251
253- apikey = cast ( str , config ["apikey" ])
254- server = cast ( str , config ["server" ])
255- short_cache_dir = cast ( str , config ["cachedir" ])
252+ apikey = str ( config ["apikey" ])
253+ server = str ( config ["server" ])
254+ short_cache_dir = Path ( config ["cachedir" ])
256255
257256 tmp_n_retries = config ["connection_n_retries" ]
258257 n_retries = int (tmp_n_retries ) if tmp_n_retries is not None else None
259258
260- set_retry_policy (cast ( str , config ["retry_policy" ]) , n_retries )
259+ set_retry_policy (config ["retry_policy" ], n_retries )
261260
262- _root_cache_directory = os . path . expanduser (short_cache_dir )
261+ _root_cache_directory = short_cache_dir . expanduser (). resolve ( )
263262 # create the cache subdirectory
264- if not os . path . exists (_root_cache_directory ):
263+ if not _root_cache_directory . exists ():
265264 try :
266- os . makedirs ( _root_cache_directory , exist_ok = True )
265+ _root_cache_directory . mkdir ( exist_ok = True , parents = True )
267266 except PermissionError :
268267 openml_logger .warning (
269268 "No permission to create openml cache directory at %s! This can result in "
@@ -288,7 +287,7 @@ def set_field_in_config_file(field: str, value: str) -> None:
288287 globals ()[field ] = value
289288 config_file = determine_config_file_path ()
290289 config = _parse_config (str (config_file ))
291- with open (config_file , "w" ) as fh :
290+ with config_file . open ("w" ) as fh :
292291 for f in _defaults :
293292 # We can't blindly set all values based on globals() because when the user
294293 # sets it through config.FIELD it should not be stored to file.
@@ -303,14 +302,15 @@ def set_field_in_config_file(field: str, value: str) -> None:
303302
304303def _parse_config (config_file : str | Path ) -> dict [str , str ]:
305304 """Parse the config file, set up defaults."""
305+ config_file = Path (config_file )
306306 config = configparser .RawConfigParser (defaults = _defaults )
307307
308308 # The ConfigParser requires a [SECTION_HEADER], which we do not expect in our config file.
309309 # Cheat the ConfigParser module by adding a fake section header
310310 config_file_ = StringIO ()
311311 config_file_ .write ("[FAKE_SECTION]\n " )
312312 try :
313- with open (config_file ) as fh :
313+ with config_file . open ("r" ) as fh :
314314 for line in fh :
315315 config_file_ .write (line )
316316 except FileNotFoundError :
@@ -326,13 +326,14 @@ def get_config_as_dict() -> dict[str, str | int | bool]:
326326 config = {} # type: Dict[str, Union[str, int, bool]]
327327 config ["apikey" ] = apikey
328328 config ["server" ] = server
329- config ["cachedir" ] = _root_cache_directory
329+ config ["cachedir" ] = str ( _root_cache_directory )
330330 config ["avoid_duplicate_runs" ] = avoid_duplicate_runs
331331 config ["connection_n_retries" ] = connection_n_retries
332332 config ["retry_policy" ] = retry_policy
333333 return config
334334
335335
336+ # NOTE: For backwards compatibility, the return type stays `str` (not `Path`).
336337def get_cache_directory () -> str :
337338 """Get the current cache directory.
338339
@@ -354,11 +355,11 @@ def get_cache_directory() -> str:
354355
355356 """
356357 url_suffix = urlparse (server ).netloc
357- reversed_url_suffix = os .sep .join (url_suffix .split ("." )[::- 1 ])
358- return os .path .join (_root_cache_directory , reversed_url_suffix )
358+ reversed_url_suffix = os .sep .join (url_suffix .split ("." )[::- 1 ]) # noqa: PTH118
359+ return os .path .join (_root_cache_directory , reversed_url_suffix ) # noqa: PTH118
359360
360361
361- def set_root_cache_directory (root_cache_directory : str ) -> None :
362+ def set_root_cache_directory (root_cache_directory : str | Path ) -> None :
362363 """Set module-wide base cache directory.
363364
363364 Sets the root cache directory, wherein the cache directories are
@@ -377,8 +378,8 @@ def set_root_cache_directory(root_cache_directory: str) -> None:
377378 --------
378379 get_cache_directory
379380 """
380- global _root_cache_directory
381- _root_cache_directory = root_cache_directory
381+ global _root_cache_directory # noqa: PLW0603
382+ _root_cache_directory = Path ( root_cache_directory )
382383
383384
384385start_using_configuration_for_example = (
0 commit comments