Skip to content

Commit 687a0f1

Browse files
authored
Refactor if-statements (#1219)
* Refactor if-statements * Add explicit names to conditional expression * Add 'dependencies' to better mimic OpenMLFlow
1 parent 5dcb7a3 commit 687a0f1

File tree

9 files changed

+37
-66
lines changed

9 files changed

+37
-66
lines changed

openml/_api_calls.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,9 +303,7 @@ def __is_checksum_equal(downloaded_file, md5_checksum=None):
303303
md5 = hashlib.md5()
304304
md5.update(downloaded_file.encode("utf-8"))
305305
md5_checksum_download = md5.hexdigest()
306-
if md5_checksum == md5_checksum_download:
307-
return True
308-
return False
306+
return md5_checksum == md5_checksum_download
309307

310308

311309
def _send_request(request_method, url, data, files=None, md5_checksum=None):

openml/datasets/dataset.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ def _get_repr_body_fields(self) -> List[Tuple[str, Union[str, int, List[str]]]]:
275275

276276
def __eq__(self, other):
277277

278-
if type(other) != OpenMLDataset:
278+
if not isinstance(other, OpenMLDataset):
279279
return False
280280

281281
server_fields = {
@@ -287,14 +287,12 @@ def __eq__(self, other):
287287
"data_file",
288288
}
289289

290-
# check that the keys are identical
290+
# check that common keys and values are identical
291291
self_keys = set(self.__dict__.keys()) - server_fields
292292
other_keys = set(other.__dict__.keys()) - server_fields
293-
if self_keys != other_keys:
294-
return False
295-
296-
# check that values of the common keys are identical
297-
return all(self.__dict__[key] == other.__dict__[key] for key in self_keys)
293+
return self_keys == other_keys and all(
294+
self.__dict__[key] == other.__dict__[key] for key in self_keys
295+
)
298296

299297
def _download_data(self) -> None:
300298
"""Download ARFF data file to standard cache directory. Set `self.data_file`."""

openml/extensions/sklearn/extension.py

Lines changed: 19 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -38,19 +38,16 @@
3838

3939
logger = logging.getLogger(__name__)
4040

41-
4241
if sys.version_info >= (3, 5):
4342
from json.decoder import JSONDecodeError
4443
else:
4544
JSONDecodeError = ValueError
4645

47-
4846
DEPENDENCIES_PATTERN = re.compile(
4947
r"^(?P<name>[\w\-]+)((?P<operation>==|>=|>)"
5048
r"(?P<version>(\d+\.)?(\d+\.)?(\d+)?(dev)?[0-9]*))?$"
5149
)
5250

53-
5451
SIMPLE_NUMPY_TYPES = [
5552
nptype
5653
for type_cat, nptypes in np.sctypes.items()
@@ -580,15 +577,11 @@ def _is_cross_validator(self, o: Any) -> bool:
580577

581578
@classmethod
582579
def _is_sklearn_flow(cls, flow: OpenMLFlow) -> bool:
583-
if getattr(flow, "dependencies", None) is not None and "sklearn" in flow.dependencies:
584-
return True
585-
if flow.external_version is None:
586-
return False
587-
else:
588-
return (
589-
flow.external_version.startswith("sklearn==")
590-
or ",sklearn==" in flow.external_version
591-
)
580+
sklearn_dependency = isinstance(flow.dependencies, str) and "sklearn" in flow.dependencies
581+
sklearn_as_external = isinstance(flow.external_version, str) and (
582+
flow.external_version.startswith("sklearn==") or ",sklearn==" in flow.external_version
583+
)
584+
return sklearn_dependency or sklearn_as_external
592585

593586
def _get_sklearn_description(self, model: Any, char_lim: int = 1024) -> str:
594587
"""Fetches the sklearn function docstring for the flow description
@@ -1867,24 +1860,22 @@ def is_subcomponent_specification(values):
18671860
# checks whether the current value can be a specification of
18681861
# subcomponents, as for example the value for steps parameter
18691862
# (in Pipeline) or transformers parameter (in
1870-
# ColumnTransformer). These are always lists/tuples of lists/
1871-
# tuples, size bigger than 2 and an OpenMLFlow item involved.
1872-
if not isinstance(values, (tuple, list)):
1873-
return False
1874-
for item in values:
1875-
if not isinstance(item, (tuple, list)):
1876-
return False
1877-
if len(item) < 2:
1878-
return False
1879-
if not isinstance(item[1], (openml.flows.OpenMLFlow, str)):
1880-
if (
1863+
# ColumnTransformer).
1864+
return (
1865+
# Specification requires list/tuple of list/tuple with
1866+
# at least length 2.
1867+
isinstance(values, (tuple, list))
1868+
and all(isinstance(item, (tuple, list)) and len(item) > 1 for item in values)
1869+
# And each component needs to be a flow or interpretable string
1870+
and all(
1871+
isinstance(item[1], openml.flows.OpenMLFlow)
1872+
or (
18811873
isinstance(item[1], str)
18821874
and item[1] in SKLEARN_PIPELINE_STRING_COMPONENTS
1883-
):
1884-
pass
1885-
else:
1886-
return False
1887-
return True
1875+
)
1876+
for item in values
1877+
)
1878+
)
18881879

18891880
# _flow is openml flow object, _param dict maps from flow name to flow
18901881
# id for the main call, the param dict can be overridden (useful for

openml/flows/functions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -261,10 +261,7 @@ def flow_exists(name: str, external_version: str) -> Union[int, bool]:
261261

262262
result_dict = xmltodict.parse(xml_response)
263263
flow_id = int(result_dict["oml:flow_exists"]["oml:id"])
264-
if flow_id > 0:
265-
return flow_id
266-
else:
267-
return False
264+
return flow_id if flow_id > 0 else False
268265

269266

270267
def get_flow_id(

openml/setups/functions.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -55,10 +55,7 @@ def setup_exists(flow) -> int:
5555
)
5656
result_dict = xmltodict.parse(result)
5757
setup_id = int(result_dict["oml:setup_exists"]["oml:id"])
58-
if setup_id > 0:
59-
return setup_id
60-
else:
61-
return False
58+
return setup_id if setup_id > 0 else False
6259

6360

6461
def _get_cached_setup(setup_id):

openml/tasks/split.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,10 @@ def __eq__(self, other):
4747
or self.name != other.name
4848
or self.description != other.description
4949
or self.split.keys() != other.split.keys()
50-
):
51-
return False
52-
53-
if any(
54-
self.split[repetition].keys() != other.split[repetition].keys()
55-
for repetition in self.split
50+
or any(
51+
self.split[repetition].keys() != other.split[repetition].keys()
52+
for repetition in self.split
53+
)
5654
):
5755
return False
5856

openml/utils.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -174,10 +174,7 @@ def _delete_entity(entity_type, entity_id):
174174
url_suffix = "%s/%d" % (entity_type, entity_id)
175175
result_xml = openml._api_calls._perform_api_call(url_suffix, "delete")
176176
result = xmltodict.parse(result_xml)
177-
if "oml:%s_delete" % entity_type in result:
178-
return True
179-
else:
180-
return False
177+
return "oml:%s_delete" % entity_type in result
181178

182179

183180
def _list_all(listing_call, output_format="dict", *args, **filters):

tests/test_extensions/test_functions.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
class DummyFlow:
1111
external_version = "DummyFlow==0.1"
12+
dependencies = None
1213

1314

1415
class DummyModel:
@@ -18,15 +19,11 @@ class DummyModel:
1819
class DummyExtension1:
1920
@staticmethod
2021
def can_handle_flow(flow):
21-
if not inspect.stack()[2].filename.endswith("test_functions.py"):
22-
return False
23-
return True
22+
return inspect.stack()[2].filename.endswith("test_functions.py")
2423

2524
@staticmethod
2625
def can_handle_model(model):
27-
if not inspect.stack()[2].filename.endswith("test_functions.py"):
28-
return False
29-
return True
26+
return inspect.stack()[2].filename.endswith("test_functions.py")
3027

3128

3229
class DummyExtension2:

tests/test_runs/test_run_functions.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ def _wait_for_processed_run(self, run_id, max_waiting_time_seconds):
127127
"evaluated correctly on the server".format(run_id)
128128
)
129129

130-
def _compare_predictions(self, predictions, predictions_prime):
130+
def _assert_predictions_equal(self, predictions, predictions_prime):
131131
self.assertEqual(
132132
np.array(predictions_prime["data"]).shape, np.array(predictions["data"]).shape
133133
)
@@ -151,8 +151,6 @@ def _compare_predictions(self, predictions, predictions_prime):
151151
else:
152152
self.assertEqual(val_1, val_2)
153153

154-
return True
155-
156154
def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed, create_task_obj):
157155
run = openml.runs.get_run(run_id)
158156

@@ -183,7 +181,7 @@ def _rerun_model_and_compare_predictions(self, run_id, model_prime, seed, create
183181

184182
predictions_prime = run_prime._generate_arff_dict()
185183

186-
self._compare_predictions(predictions, predictions_prime)
184+
self._assert_predictions_equal(predictions, predictions_prime)
187185
pd.testing.assert_frame_equal(
188186
run.predictions,
189187
run_prime.predictions,

0 commit comments

Comments
 (0)