Skip to content

Commit a7bcf07

Browse files
Add more type annotations (#1261)
* Add more type annotations * add to progress rst --------- Co-authored-by: Lennart Purucker <[email protected]>
1 parent d6283e8 commit a7bcf07

File tree

12 files changed

+110
-93
lines changed

12 files changed

+110
-93
lines changed

doc/progress.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ next
1111

1212
* MAINT #1280: Use the server-provided ``parquet_url`` instead of ``minio_url`` to determine the location of the parquet file.
1313
* ADD #716: add documentation for remaining attributes of classes and functions.
14+
* ADD #1261: more annotations for type hints.
1415

1516
0.14.1
1617
~~~~~~

openml/base.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
class OpenMLBase(ABC):
1616
"""Base object for functionality that is shared across entities."""
1717

18-
def __repr__(self):
18+
def __repr__(self) -> str:
1919
body_fields = self._get_repr_body_fields()
2020
return self._apply_repr_template(body_fields)
2121

@@ -59,7 +59,9 @@ def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
5959
# Should be implemented in the base class.
6060
pass
6161

62-
def _apply_repr_template(self, body_fields: List[Tuple[str, str]]) -> str:
62+
def _apply_repr_template(
63+
self, body_fields: List[Tuple[str, Union[str, int, List[str]]]]
64+
) -> str:
6365
"""Generates the header and formats the body for string representation of the object.
6466
6567
Parameters
@@ -80,7 +82,7 @@ def _apply_repr_template(self, body_fields: List[Tuple[str, str]]) -> str:
8082
return header + body
8183

8284
@abstractmethod
83-
def _to_dict(self) -> "OrderedDict[str, OrderedDict]":
85+
def _to_dict(self) -> "OrderedDict[str, OrderedDict[str, str]]":
8486
"""Creates a dictionary representation of self.
8587
8688
Uses OrderedDict to ensure consistent ordering when converting to xml.
@@ -107,7 +109,7 @@ def _to_xml(self) -> str:
107109
encoding_specification, xml_body = xml_representation.split("\n", 1)
108110
return xml_body
109111

110-
def _get_file_elements(self) -> Dict:
112+
def _get_file_elements(self) -> openml._api_calls.FILE_ELEMENTS_TYPE:
111113
"""Get file_elements to upload to the server, called during Publish.
112114
113115
Derived child classes should overwrite this method as necessary.
@@ -116,7 +118,7 @@ def _get_file_elements(self) -> Dict:
116118
return {}
117119

118120
@abstractmethod
119-
def _parse_publish_response(self, xml_response: Dict):
121+
def _parse_publish_response(self, xml_response: Dict[str, str]) -> None:
120122
"""Parse the id from the xml_response and assign it to self."""
121123
pass
122124

@@ -135,11 +137,16 @@ def publish(self) -> "OpenMLBase":
135137
self._parse_publish_response(xml_response)
136138
return self
137139

138-
def open_in_browser(self):
140+
def open_in_browser(self) -> None:
139141
"""Opens the OpenML web page corresponding to this object in your default browser."""
140-
webbrowser.open(self.openml_url)
141-
142-
def push_tag(self, tag: str):
142+
if self.openml_url is None:
143+
raise ValueError(
144+
"Cannot open element on OpenML.org when attribute `openml_url` is `None`"
145+
)
146+
else:
147+
webbrowser.open(self.openml_url)
148+
149+
def push_tag(self, tag: str) -> None:
143150
"""Annotates this entity with a tag on the server.
144151
145152
Parameters
@@ -149,7 +156,7 @@ def push_tag(self, tag: str):
149156
"""
150157
_tag_openml_base(self, tag)
151158

152-
def remove_tag(self, tag: str):
159+
def remove_tag(self, tag: str) -> None:
153160
"""Removes a tag from this entity on the server.
154161
155162
Parameters

openml/cli.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def wait_until_valid_input(
5555
return response
5656

5757

58-
def print_configuration():
58+
def print_configuration() -> None:
5959
file = config.determine_config_file_path()
6060
header = f"File '{file}' contains (or defaults to):"
6161
print(header)
@@ -65,7 +65,7 @@ def print_configuration():
6565
print(f"{field.ljust(max_key_length)}: {value}")
6666

6767

68-
def verbose_set(field, value):
68+
def verbose_set(field: str, value: str) -> None:
6969
config.set_field_in_config_file(field, value)
7070
print(f"{field} set to '{value}'.")
7171

@@ -295,7 +295,7 @@ def configure_field(
295295
verbose_set(field, value)
296296

297297

298-
def configure(args: argparse.Namespace):
298+
def configure(args: argparse.Namespace) -> None:
299299
"""Calls the right submenu(s) to edit `args.field` in the configuration file."""
300300
set_functions = {
301301
"apikey": configure_apikey,
@@ -307,7 +307,7 @@ def configure(args: argparse.Namespace):
307307
"verbosity": configure_verbosity,
308308
}
309309

310-
def not_supported_yet(_):
310+
def not_supported_yet(_: str) -> None:
311311
print(f"Setting '{args.field}' is not supported yet.")
312312

313313
if args.field not in ["all", "none"]:

openml/config.py

Lines changed: 34 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import os
1010
from pathlib import Path
1111
import platform
12-
from typing import Tuple, cast, Any, Optional
12+
from typing import Dict, Optional, Tuple, Union, cast
1313
import warnings
1414

1515
from io import StringIO
@@ -19,10 +19,10 @@
1919
logger = logging.getLogger(__name__)
2020
openml_logger = logging.getLogger("openml")
2121
console_handler = None
22-
file_handler = None
22+
file_handler = None # type: Optional[logging.Handler]
2323

2424

25-
def _create_log_handlers(create_file_handler=True):
25+
def _create_log_handlers(create_file_handler: bool = True) -> None:
2626
"""Creates but does not attach the log handlers."""
2727
global console_handler, file_handler
2828
if console_handler is not None or file_handler is not None:
@@ -61,7 +61,7 @@ def _convert_log_levels(log_level: int) -> Tuple[int, int]:
6161
return openml_level, python_level
6262

6363

64-
def _set_level_register_and_store(handler: logging.Handler, log_level: int):
64+
def _set_level_register_and_store(handler: logging.Handler, log_level: int) -> None:
6565
"""Set handler log level, register it if needed, save setting to config file if specified."""
6666
oml_level, py_level = _convert_log_levels(log_level)
6767
handler.setLevel(py_level)
@@ -73,13 +73,13 @@ def _set_level_register_and_store(handler: logging.Handler, log_level: int):
7373
openml_logger.addHandler(handler)
7474

7575

76-
def set_console_log_level(console_output_level: int):
76+
def set_console_log_level(console_output_level: int) -> None:
7777
"""Set console output to the desired level and register it with openml logger if needed."""
7878
global console_handler
7979
_set_level_register_and_store(cast(logging.Handler, console_handler), console_output_level)
8080

8181

82-
def set_file_log_level(file_output_level: int):
82+
def set_file_log_level(file_output_level: int) -> None:
8383
"""Set file output to the desired level and register it with openml logger if needed."""
8484
global file_handler
8585
_set_level_register_and_store(cast(logging.Handler, file_handler), file_output_level)
@@ -139,7 +139,8 @@ def set_retry_policy(value: str, n_retries: Optional[int] = None) -> None:
139139

140140
if value not in default_retries_by_policy:
141141
raise ValueError(
142-
f"Detected retry_policy '{value}' but must be one of {default_retries_by_policy}"
142+
f"Detected retry_policy '{value}' but must be one of "
143+
f"{list(default_retries_by_policy.keys())}"
143144
)
144145
if n_retries is not None and not isinstance(n_retries, int):
145146
raise TypeError(f"`n_retries` must be of type `int` or `None` but is `{type(n_retries)}`.")
@@ -160,7 +161,7 @@ class ConfigurationForExamples:
160161
_test_apikey = "c0c42819af31e706efe1f4b88c23c6c1"
161162

162163
@classmethod
163-
def start_using_configuration_for_example(cls):
164+
def start_using_configuration_for_example(cls) -> None:
164165
"""Sets the configuration to connect to the test server with valid apikey.
165166
166167
The configuration as it was before this call is stored, and can be recovered
@@ -187,7 +188,7 @@ def start_using_configuration_for_example(cls):
187188
)
188189

189190
@classmethod
190-
def stop_using_configuration_for_example(cls):
191+
def stop_using_configuration_for_example(cls) -> None:
191192
"""Return to configuration as it was before `start_using_configuration_for_example`."""
192193
if not cls._start_last_called:
193194
# We don't want to allow this because it will (likely) result in the `server` and
@@ -200,8 +201,8 @@ def stop_using_configuration_for_example(cls):
200201
global server
201202
global apikey
202203

203-
server = cls._last_used_server
204-
apikey = cls._last_used_key
204+
server = cast(str, cls._last_used_server)
205+
apikey = cast(str, cls._last_used_key)
205206
cls._start_last_called = False
206207

207208

@@ -215,7 +216,7 @@ def determine_config_file_path() -> Path:
215216
return config_dir / "config"
216217

217218

218-
def _setup(config=None):
219+
def _setup(config: Optional[Dict[str, Union[str, int, bool]]] = None) -> None:
219220
"""Setup openml package. Called on first import.
220221
221222
Reads the config file and sets up apikey, server, cache appropriately.
@@ -243,28 +244,22 @@ def _setup(config=None):
243244
cache_exists = True
244245

245246
if config is None:
246-
config = _parse_config(config_file)
247+
config = cast(Dict[str, Union[str, int, bool]], _parse_config(config_file))
248+
config = cast(Dict[str, Union[str, int, bool]], config)
247249

248-
def _get(config, key):
249-
return config.get("FAKE_SECTION", key)
250+
avoid_duplicate_runs = bool(config.get("avoid_duplicate_runs"))
250251

251-
avoid_duplicate_runs = config.getboolean("FAKE_SECTION", "avoid_duplicate_runs")
252-
else:
253-
254-
def _get(config, key):
255-
return config.get(key)
256-
257-
avoid_duplicate_runs = config.get("avoid_duplicate_runs")
252+
apikey = cast(str, config["apikey"])
253+
server = cast(str, config["server"])
254+
short_cache_dir = cast(str, config["cachedir"])
258255

259-
apikey = _get(config, "apikey")
260-
server = _get(config, "server")
261-
short_cache_dir = _get(config, "cachedir")
262-
263-
n_retries = _get(config, "connection_n_retries")
264-
if n_retries is not None:
265-
n_retries = int(n_retries)
256+
tmp_n_retries = config["connection_n_retries"]
257+
if tmp_n_retries is not None:
258+
n_retries = int(tmp_n_retries)
259+
else:
260+
n_retries = None
266261

267-
set_retry_policy(_get(config, "retry_policy"), n_retries)
262+
set_retry_policy(cast(str, config["retry_policy"]), n_retries)
268263

269264
_root_cache_directory = os.path.expanduser(short_cache_dir)
270265
# create the cache subdirectory
@@ -287,10 +282,10 @@ def _get(config, key):
287282
)
288283

289284

290-
def set_field_in_config_file(field: str, value: Any):
285+
def set_field_in_config_file(field: str, value: str) -> None:
291286
"""Overwrites the `field` in the configuration file with the new `value`."""
292287
if field not in _defaults:
293-
return ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.")
288+
raise ValueError(f"Field '{field}' is not valid and must be one of '{_defaults.keys()}'.")
294289

295290
globals()[field] = value
296291
config_file = determine_config_file_path()
@@ -308,7 +303,7 @@ def set_field_in_config_file(field: str, value: Any):
308303
fh.write(f"{f} = {value}\n")
309304

310305

311-
def _parse_config(config_file: str):
306+
def _parse_config(config_file: Union[str, Path]) -> Dict[str, str]:
312307
"""Parse the config file, set up defaults."""
313308
config = configparser.RawConfigParser(defaults=_defaults)
314309

@@ -326,11 +321,12 @@ def _parse_config(config_file: str):
326321
logger.info("Error opening file %s: %s", config_file, e.args[0])
327322
config_file_.seek(0)
328323
config.read_file(config_file_)
329-
return config
324+
config_as_dict = {key: value for key, value in config.items("FAKE_SECTION")}
325+
return config_as_dict
330326

331327

332-
def get_config_as_dict():
333-
config = dict()
328+
def get_config_as_dict() -> Dict[str, Union[str, int, bool]]:
329+
config = dict() # type: Dict[str, Union[str, int, bool]]
334330
config["apikey"] = apikey
335331
config["server"] = server
336332
config["cachedir"] = _root_cache_directory
@@ -340,7 +336,7 @@ def get_config_as_dict():
340336
return config
341337

342338

343-
def get_cache_directory():
339+
def get_cache_directory() -> str:
344340
"""Get the current cache directory.
345341
346342
This gets the cache directory for the current server relative
@@ -366,7 +362,7 @@ def get_cache_directory():
366362
return _cachedir
367363

368364

369-
def set_root_cache_directory(root_cache_directory):
365+
def set_root_cache_directory(root_cache_directory: str) -> None:
370366
"""Set module-wide base cache directory.
371367
372368
Sets the root cache directory, wherein the cache directories are

openml/exceptions.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# License: BSD 3-Clause
22

3-
from typing import Optional
3+
from typing import Optional, Set
44

55

66
class PyOpenMLError(Exception):
@@ -28,7 +28,7 @@ def __init__(self, message: str, code: Optional[int] = None, url: Optional[str]
2828
self.url = url
2929
super().__init__(message)
3030

31-
def __str__(self):
31+
def __str__(self) -> str:
3232
return f"{self.url} returned code {self.code}: {self.message}"
3333

3434

@@ -59,7 +59,7 @@ class OpenMLPrivateDatasetError(PyOpenMLError):
5959
class OpenMLRunsExistError(PyOpenMLError):
6060
"""Indicates run(s) already exists on the server when they should not be duplicated."""
6161

62-
def __init__(self, run_ids: set, message: str):
62+
def __init__(self, run_ids: Set[int], message: str) -> None:
6363
if len(run_ids) < 1:
6464
raise ValueError("Set of run ids must be non-empty.")
6565
self.run_ids = run_ids

0 commit comments

Comments
 (0)