Skip to content

Commit 43c66aa

Browse files
eddiebergmanLennartPuruckerpre-commit-ci[bot]
authored
fix: Chipping away at ruff lints (#1303)
* fix: Chipping away at ruff lints * fix: return lockfile path * Update openml/config.py * Update openml/runs/functions.py * Update openml/tasks/functions.py * Update openml/tasks/split.py * Update openml/utils.py * Update openml/utils.py * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update openml/config.py * Update openml/testing.py * Update openml/utils.py * Update openml/config.py * Update openml/utils.py * Update openml/utils.py * add concurrency to workflow calls * adjust docstring * adjust docstring * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: Lennart Purucker <[email protected]> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Lennart Purucker <[email protected]>
1 parent e435706 commit 43c66aa

File tree

14 files changed

+446
-345
lines changed

14 files changed

+446
-345
lines changed

.github/workflows/test.yml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ name: Tests
22

33
on: [push, pull_request]
44

5+
concurrency:
6+
group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
7+
cancel-in-progress: true
8+
59
jobs:
610
test:
711
name: (${{ matrix.os }}, Py${{ matrix.python-version }}, sk${{ matrix.scikit-learn }}, sk-only:${{ matrix.sklearn-only }})

openml/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -117,4 +117,5 @@ def populate_cache(task_ids=None, dataset_ids=None, flow_ids=None, run_ids=None)
117117
]
118118

119119
# Load the scikit-learn extension by default
120-
import openml.extensions.sklearn # noqa: F401
120+
# TODO(eddiebergman): Not sure why this is at the bottom of the file
121+
import openml.extensions.sklearn # noqa: E402, F401

openml/config.py

Lines changed: 52 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,18 @@
1212
from io import StringIO
1313
from pathlib import Path
1414
from typing import Dict, Union, cast
15+
from typing_extensions import Literal
1516
from urllib.parse import urlparse
1617

1718
logger = logging.getLogger(__name__)
1819
openml_logger = logging.getLogger("openml")
19-
console_handler = None
20-
file_handler = None # type: Optional[logging.Handler]
20+
console_handler: logging.StreamHandler | None = None
21+
file_handler: logging.handlers.RotatingFileHandler | None = None
2122

2223

23-
def _create_log_handlers(create_file_handler: bool = True) -> None:
24+
def _create_log_handlers(create_file_handler: bool = True) -> None: # noqa: FBT
2425
"""Creates but does not attach the log handlers."""
25-
global console_handler, file_handler
26+
global console_handler, file_handler # noqa: PLW0603
2627
if console_handler is not None or file_handler is not None:
2728
logger.debug("Requested to create log handlers, but they are already created.")
2829
return
@@ -35,7 +36,7 @@ def _create_log_handlers(create_file_handler: bool = True) -> None:
3536

3637
if create_file_handler:
3738
one_mb = 2**20
38-
log_path = os.path.join(_root_cache_directory, "openml_python.log")
39+
log_path = _root_cache_directory / "openml_python.log"
3940
file_handler = logging.handlers.RotatingFileHandler(
4041
log_path,
4142
maxBytes=one_mb,
@@ -64,7 +65,7 @@ def _convert_log_levels(log_level: int) -> tuple[int, int]:
6465

6566
def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None:
6667
"""Set handler log level, register it if needed, save setting to config file if specified."""
67-
oml_level, py_level = _convert_log_levels(log_level)
68+
_oml_level, py_level = _convert_log_levels(log_level)
6869
handler.setLevel(py_level)
6970

7071
if openml_logger.level > py_level or openml_logger.level == logging.NOTSET:
@@ -76,31 +77,27 @@ def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> N
7677

7778
def set_console_log_level(console_output_level: int) -> None:
7879
"""Set console output to the desired level and register it with openml logger if needed."""
79-
global console_handler
80-
_set_level_register_and_store(cast(logging.Handler, console_handler), console_output_level)
80+
global console_handler # noqa: PLW0602
81+
assert console_handler is not None
82+
_set_level_register_and_store(console_handler, console_output_level)
8183

8284

8385
def set_file_log_level(file_output_level: int) -> None:
8486
"""Set file output to the desired level and register it with openml logger if needed."""
85-
global file_handler
86-
_set_level_register_and_store(cast(logging.Handler, file_handler), file_output_level)
87+
global file_handler # noqa: PLW0602
88+
assert file_handler is not None
89+
_set_level_register_and_store(file_handler, file_output_level)
8790

8891

8992
# Default values (see also https://github.com/openml/OpenML/wiki/Client-API-Standards)
93+
_user_path = Path("~").expanduser().absolute()
9094
_defaults = {
9195
"apikey": "",
9296
"server": "https://www.openml.org/api/v1/xml",
9397
"cachedir": (
94-
os.environ.get(
95-
"XDG_CACHE_HOME",
96-
os.path.join(
97-
"~",
98-
".cache",
99-
"openml",
100-
),
101-
)
98+
os.environ.get("XDG_CACHE_HOME", _user_path / ".cache" / "openml")
10299
if platform.system() == "Linux"
103-
else os.path.join("~", ".openml")
100+
else _user_path / ".openml"
104101
),
105102
"avoid_duplicate_runs": "True",
106103
"retry_policy": "human",
@@ -124,18 +121,18 @@ def get_server_base_url() -> str:
124121
return server.split("/api")[0]
125122

126123

127-
apikey = _defaults["apikey"]
124+
apikey: str = _defaults["apikey"]
128125
# The current cache directory (without the server name)
129-
_root_cache_directory = str(_defaults["cachedir"]) # so mypy knows it is a string
130-
avoid_duplicate_runs = _defaults["avoid_duplicate_runs"] == "True"
126+
_root_cache_directory = Path(_defaults["cachedir"])
127+
avoid_duplicate_runs: bool = _defaults["avoid_duplicate_runs"] == "True"
131128

132129
retry_policy = _defaults["retry_policy"]
133130
connection_n_retries = int(_defaults["connection_n_retries"])
134131

135132

136-
def set_retry_policy(value: str, n_retries: int | None = None) -> None:
137-
global retry_policy
138-
global connection_n_retries
133+
def set_retry_policy(value: Literal["human", "robot"], n_retries: int | None = None) -> None:
134+
global retry_policy # noqa: PLW0603
135+
global connection_n_retries # noqa: PLW0603
139136
default_retries_by_policy = {"human": 5, "robot": 50}
140137

141138
if value not in default_retries_by_policy:
@@ -145,6 +142,7 @@ def set_retry_policy(value: str, n_retries: int | None = None) -> None:
145142
)
146143
if n_retries is not None and not isinstance(n_retries, int):
147144
raise TypeError(f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`.")
145+
148146
if isinstance(n_retries, int) and n_retries < 1:
149147
raise ValueError(f"`n_retries` is '{n_retries}' but must be positive.")
150148

@@ -168,8 +166,8 @@ def start_using_configuration_for_example(cls) -> None:
168166
To configuration as was before this call is stored, and can be recovered
169167
by using the `stop_use_example_configuration` method.
170168
"""
171-
global server
172-
global apikey
169+
global server # noqa: PLW0603
170+
global apikey # noqa: PLW0603
173171

174172
if cls._start_last_called and server == cls._test_server and apikey == cls._test_apikey:
175173
# Method is called more than once in a row without modifying the server or apikey.
@@ -186,6 +184,7 @@ def start_using_configuration_for_example(cls) -> None:
186184
warnings.warn(
187185
f"Switching to the test server {server} to not upload results to the live server. "
188186
"Using the test server may result in reduced performance of the API!",
187+
stacklevel=2,
189188
)
190189

191190
@classmethod
@@ -199,8 +198,8 @@ def stop_using_configuration_for_example(cls) -> None:
199198
"`start_use_example_configuration` must be called first.",
200199
)
201200

202-
global server
203-
global apikey
201+
global server # noqa: PLW0603
202+
global apikey # noqa: PLW0603
204203

205204
server = cast(str, cls._last_used_server)
206205
apikey = cast(str, cls._last_used_key)
@@ -213,7 +212,7 @@ def determine_config_file_path() -> Path:
213212
else:
214213
config_dir = Path("~") / ".openml"
215214
# Still use os.path.expanduser to trigger the mock in the unit test
216-
config_dir = Path(os.path.expanduser(config_dir))
215+
config_dir = Path(config_dir).expanduser().resolve()
217216
return config_dir / "config"
218217

219218

@@ -226,18 +225,18 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:
226225
openml.config.server = SOMESERVER
227226
We could also make it a property but that's less clear.
228227
"""
229-
global apikey
230-
global server
231-
global _root_cache_directory
232-
global avoid_duplicate_runs
228+
global apikey # noqa: PLW0603
229+
global server # noqa: PLW0603
230+
global _root_cache_directory # noqa: PLW0603
231+
global avoid_duplicate_runs # noqa: PLW0603
233232

234233
config_file = determine_config_file_path()
235234
config_dir = config_file.parent
236235

237236
# read config file, create directory for config file
238-
if not os.path.exists(config_dir):
237+
if not config_dir.exists():
239238
try:
240-
os.makedirs(config_dir, exist_ok=True)
239+
config_dir.mkdir(exist_ok=True, parents=True)
241240
cache_exists = True
242241
except PermissionError:
243242
cache_exists = False
@@ -250,20 +249,20 @@ def _setup(config: dict[str, str | int | bool] | None = None) -> None:
250249

251250
avoid_duplicate_runs = bool(config.get("avoid_duplicate_runs"))
252251

253-
apikey = cast(str, config["apikey"])
254-
server = cast(str, config["server"])
255-
short_cache_dir = cast(str, config["cachedir"])
252+
apikey = str(config["apikey"])
253+
server = str(config["server"])
254+
short_cache_dir = Path(config["cachedir"])
256255

257256
tmp_n_retries = config["connection_n_retries"]
258257
n_retries = int(tmp_n_retries) if tmp_n_retries is not None else None
259258

260-
set_retry_policy(cast(str, config["retry_policy"]), n_retries)
259+
set_retry_policy(config["retry_policy"], n_retries)
261260

262-
_root_cache_directory = os.path.expanduser(short_cache_dir)
261+
_root_cache_directory = short_cache_dir.expanduser().resolve()
263262
# create the cache subdirectory
264-
if not os.path.exists(_root_cache_directory):
263+
if not _root_cache_directory.exists():
265264
try:
266-
os.makedirs(_root_cache_directory, exist_ok=True)
265+
_root_cache_directory.mkdir(exist_ok=True, parents=True)
267266
except PermissionError:
268267
openml_logger.warning(
269268
"No permission to create openml cache directory at %s! This can result in "
@@ -288,7 +287,7 @@ def set_field_in_config_file(field: str, value: str) -> None:
288287
globals()[field] = value
289288
config_file = determine_config_file_path()
290289
config = _parse_config(str(config_file))
291-
with open(config_file, "w") as fh:
290+
with config_file.open("w") as fh:
292291
for f in _defaults:
293292
# We can't blindly set all values based on globals() because when the user
294293
# sets it through config.FIELD it should not be stored to file.
@@ -303,14 +302,15 @@ def set_field_in_config_file(field: str, value: str) -> None:
303302

304303
def _parse_config(config_file: str | Path) -> dict[str, str]:
305304
"""Parse the config file, set up defaults."""
305+
config_file = Path(config_file)
306306
config = configparser.RawConfigParser(defaults=_defaults)
307307

308308
# The ConfigParser requires a [SECTION_HEADER], which we do not expect in our config file.
309309
# Cheat the ConfigParser module by adding a fake section header
310310
config_file_ = StringIO()
311311
config_file_.write("[FAKE_SECTION]\n")
312312
try:
313-
with open(config_file) as fh:
313+
with config_file.open("w") as fh:
314314
for line in fh:
315315
config_file_.write(line)
316316
except FileNotFoundError:
@@ -326,13 +326,14 @@ def get_config_as_dict() -> dict[str, str | int | bool]:
326326
config = {} # type: Dict[str, Union[str, int, bool]]
327327
config["apikey"] = apikey
328328
config["server"] = server
329-
config["cachedir"] = _root_cache_directory
329+
config["cachedir"] = str(_root_cache_directory)
330330
config["avoid_duplicate_runs"] = avoid_duplicate_runs
331331
config["connection_n_retries"] = connection_n_retries
332332
config["retry_policy"] = retry_policy
333333
return config
334334

335335

336+
# NOTE: For backwards compatibility, we keep the `str`
336337
def get_cache_directory() -> str:
337338
"""Get the current cache directory.
338339
@@ -354,11 +355,11 @@ def get_cache_directory() -> str:
354355
355356
"""
356357
url_suffix = urlparse(server).netloc
357-
reversed_url_suffix = os.sep.join(url_suffix.split(".")[::-1])
358-
return os.path.join(_root_cache_directory, reversed_url_suffix)
358+
reversed_url_suffix = os.sep.join(url_suffix.split(".")[::-1]) # noqa: PTH118
359+
return os.path.join(_root_cache_directory, reversed_url_suffix) # noqa: PTH118
359360

360361

361-
def set_root_cache_directory(root_cache_directory: str) -> None:
362+
def set_root_cache_directory(root_cache_directory: str | Path) -> None:
362363
"""Set module-wide base cache directory.
363364
364365
Sets the root cache directory, wherin the cache directories are
@@ -377,8 +378,8 @@ def set_root_cache_directory(root_cache_directory: str) -> None:
377378
--------
378379
get_cache_directory
379380
"""
380-
global _root_cache_directory
381-
_root_cache_directory = root_cache_directory
381+
global _root_cache_directory # noqa: PLW0603
382+
_root_cache_directory = Path(root_cache_directory)
382383

383384

384385
start_using_configuration_for_example = (

openml/extensions/sklearn/__init__.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,21 @@
11
# License: BSD 3-Clause
2+
from __future__ import annotations
3+
4+
from typing import TYPE_CHECKING
25

36
from openml.extensions import register_extension
47

58
from .extension import SklearnExtension
69

10+
if TYPE_CHECKING:
11+
import pandas as pd
12+
713
__all__ = ["SklearnExtension"]
814

915
register_extension(SklearnExtension)
1016

1117

12-
def cont(X):
18+
def cont(X: pd.DataFrame) -> pd.Series:
1319
"""Returns True for all non-categorical columns, False for the rest.
1420
1521
This is a helper function for OpenML datasets encoded as DataFrames simplifying the handling
@@ -23,7 +29,7 @@ def cont(X):
2329
return X.dtypes != "category"
2430

2531

26-
def cat(X):
32+
def cat(X: pd.DataFrame) -> pd.Series:
2733
"""Returns True for all categorical columns, False for the rest.
2834
2935
This is a helper function for OpenML datasets encoded as DataFrames simplifying the handling

0 commit comments

Comments
 (0)