Skip to content

Commit 53daf7e

Browse files
authored
Leaderboard rank fix (#1191)
* Fixes for valid parameters not being tested * flake8'd
1 parent 6231b1c commit 53daf7e

File tree

2 files changed

+82
-61
lines changed

2 files changed

+82
-61
lines changed

autosklearn/estimators.py

Lines changed: 30 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
# -*- encoding: utf-8 -*-
2-
from typing import Optional, Dict, List, Tuple, Union, Iterable, ClassVar
2+
from typing import Optional, Dict, List, Tuple, Union, Iterable
33
from typing_extensions import Literal
44

55
from ConfigSpace.configuration_space import Configuration
@@ -22,18 +22,6 @@
2222

2323

2424
class AutoSklearnEstimator(BaseEstimator):
25-
# Constants used by `def leaderboard` for columns and their sort order
26-
_leaderboard_columns: ClassVar[Dict[str, List[str]]] = {
27-
"all": [
28-
"model_id", "rank", "ensemble_weight", "type", "cost", "duration",
29-
"config_id", "train_loss", "seed", "start_time", "end_time",
30-
"budget", "status", "data_preprocessors", "feature_preprocessors",
31-
"balancing_strategy", "config_origin"
32-
],
33-
"simple": [
34-
"model_id", "rank", "ensemble_weight", "type", "cost", "duration"
35-
]
36-
}
3725

3826
def __init__(
3927
self,
@@ -642,11 +630,7 @@ def leaderboard(
642630
# TODO validate that `self` is fitted. This is required for
643631
# self.ensemble_ to get the identifiers of models it will generate
644632
# weights for.
645-
column_types = {
646-
'all': AutoSklearnEstimator._leaderboard_columns['all'],
647-
'simple': AutoSklearnEstimator._leaderboard_columns['simple'],
648-
'detailed': AutoSklearnEstimator._leaderboard_columns['all']
649-
}
633+
column_types = AutoSklearnEstimator._leaderboard_columns()
650634

651635
# Validation of top_k
652636
if (
@@ -661,6 +645,9 @@ def leaderboard(
661645
if isinstance(include, str):
662646
include = [include]
663647

648+
if include == ['model_id']:
649+
raise ValueError('Must provide more than just `model_id`')
650+
664651
if include is not None:
665652
columns = [*include]
666653

@@ -784,10 +771,10 @@ def has_key(rv, key):
784771
# Add the `rank` column if needed, dropping `cost` if it's not
785772
# requested by the user
786773
if 'rank' in columns:
787-
dataframe.sort_values(by='cost', ascending=False, inplace=True)
774+
dataframe.sort_values(by='cost', ascending=True, inplace=True)
788775
dataframe.insert(column='rank',
789776
value=range(1, len(dataframe) + 1),
790-
loc=list(columns).index('rank'))
777+
loc=list(columns).index('rank') - 1) # account for `model_id`
791778

792779
if 'cost' not in columns:
793780
dataframe.drop('cost', inplace=True)
@@ -806,9 +793,15 @@ def has_key(rv, key):
806793
"'model_id'")
807794
sort_by = 'model_id'
808795

809-
dataframe.sort_values(by=sort_by,
810-
ascending=ascending_param,
811-
inplace=True)
796+
# Rows can share the same cost, which would leave rank out of order; break ties by rank
797+
if 'rank' in columns and sort_by == 'cost':
798+
dataframe.sort_values(by=[sort_by, 'rank'],
799+
ascending=[ascending_param, True],
800+
inplace=True)
801+
else:
802+
dataframe.sort_values(by=sort_by,
803+
ascending=ascending_param,
804+
inplace=True)
812805

813806
# Lastly, just grab the top_k
814807
if top_k == 'all' or top_k >= len(dataframe):
@@ -818,6 +811,20 @@ def has_key(rv, key):
818811

819812
return dataframe
820813

814+
@staticmethod
815+
def _leaderboard_columns() -> Dict[Literal['all', 'simple', 'detailed'], List[str]]:
816+
all = [
817+
"model_id", "rank", "ensemble_weight", "type", "cost", "duration",
818+
"config_id", "train_loss", "seed", "start_time", "end_time",
819+
"budget", "status", "data_preprocessors", "feature_preprocessors",
820+
"balancing_strategy", "config_origin"
821+
]
822+
simple = [
823+
"model_id", "rank", "ensemble_weight", "type", "cost", "duration"
824+
]
825+
detailed = all
826+
return {'all': all, 'detailed': detailed, 'simple': simple}
827+
821828
def _get_automl_class(self):
822829
raise NotImplementedError()
823830

test/test_automl/test_estimators.py

Lines changed: 52 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -331,22 +331,9 @@ def test_leaderboard(
331331
):
332332
# Comprehensive test takes a substantial amount of time, manually set if
333333
# required.
334-
MAX_COMBO_SIZE_FOR_INCLUDE_PARAM = 2 # [0, len(valid_columns) + 1]
334+
MAX_COMBO_SIZE_FOR_INCLUDE_PARAM = 3 # [0, len(valid_columns) + 1]
335+
column_types = AutoSklearnEstimator._leaderboard_columns()
335336

336-
X_train, Y_train, _, _ = putil.get_dataset(dataset_name)
337-
model = estimator_type(
338-
time_left_for_this_task=30,
339-
per_run_time_limit=5,
340-
tmp_folder=tmp_dir,
341-
seed=1
342-
)
343-
model.fit(X_train, Y_train)
344-
345-
column_types = {
346-
'all': AutoSklearnEstimator._leaderboard_columns['all'],
347-
'simple': AutoSklearnEstimator._leaderboard_columns['simple'],
348-
'detailed': AutoSklearnEstimator._leaderboard_columns['all']
349-
}
350337
# Create a dict of all possible param values for each param
351338
# with some invalid ones of the incorrect type
352339
include_combinations = itertools.chain(
@@ -357,7 +344,7 @@ def test_leaderboard(
357344
'detailed': [True, False],
358345
'ensemble_only': [True, False],
359346
'top_k': [-10, 0, 1, 10, 'all'],
360-
'sort_by': [column_types['all'], 'invalid'],
347+
'sort_by': [*column_types['all'], 'invalid'],
361348
'sort_order': ['ascending', 'descending', 'auto', 'invalid', None],
362349
'include': itertools.chain([None, 'invalid', 'type'], include_combinations),
363350
}
@@ -368,7 +355,19 @@ def test_leaderboard(
368355
for param_values in itertools.product(*valid_params.values())
369356
)
370357

358+
X_train, Y_train, _, _ = putil.get_dataset(dataset_name)
359+
model = estimator_type(
360+
time_left_for_this_task=30,
361+
per_run_time_limit=5,
362+
tmp_folder=tmp_dir,
363+
seed=1
364+
)
365+
model.fit(X_train, Y_train)
366+
371367
for params in params_generator:
368+
# Convert from iterator to solid list
369+
if params['include'] is not None and not isinstance(params['include'], str):
370+
params['include'] = list(params['include'])
372371

373372
# Invalid top_k should raise an error; it must be a positive int or 'all'
374373
if not (params['top_k'] == 'all' or params['top_k'] > 0):
@@ -385,26 +384,32 @@ def test_leaderboard(
385384
with pytest.raises(ValueError):
386385
model.leaderboard(**params)
387386

388-
# Invalid include item in a list
389-
elif params['include'] is not None:
390-
# Crash if just a str but invalid column
391-
if (
392-
isinstance(params['include'], str)
393-
and params['include'] not in column_types['all']
394-
):
395-
with pytest.raises(ValueError):
396-
model.leaderboard(**params)
397-
# Crash if list but contains invalid column
398-
elif (
399-
not isinstance(params['include'], str)
400-
and len(set(params['include']) - set(column_types['all'])) != 0
401-
):
402-
with pytest.raises(ValueError):
403-
model.leaderboard(**params)
387+
# include is single str but not valid
388+
elif (
389+
isinstance(params['include'], str)
390+
and params['include'] not in column_types['all']
391+
):
392+
with pytest.raises(ValueError):
393+
model.leaderboard(**params)
394+
395+
# Crash if include is list but contains invalid column
396+
elif (
397+
isinstance(params['include'], list)
398+
and len(set(params['include']) - set(column_types['all'])) != 0
399+
):
400+
with pytest.raises(ValueError):
401+
model.leaderboard(**params)
402+
403+
# Can't have just model_id, in both single str and list case
404+
elif (
405+
params['include'] == 'model_id'
406+
or params['include'] == ['model_id']
407+
):
408+
with pytest.raises(ValueError):
409+
model.leaderboard(**params)
404410

405-
# Should run without an error if all params are valid
411+
# Else all valid combinations should be validated
406412
else:
407-
# Validate the outputs
408413
leaderboard = model.leaderboard(**params)
409414

410415
# top_k should never be less than the rows given back
@@ -413,14 +418,23 @@ def test_leaderboard(
413418
assert params['top_k'] >= len(leaderboard)
414419

415420
# Check the right columns are present and in the right order
416-
# The id is set as the index but is not included in pandas columns
421+
# The model_id is set as the index, not included in pandas columns
417422
columns = list(leaderboard.columns)
423+
424+
def exclude(lst, s):
425+
return [x for x in lst if x != s]
426+
418427
if params['include'] is not None:
419-
assert columns == list(params['include'])
428+
# When include is a single str, it should be the only column
429+
if isinstance(params['include'], str):
430+
assert params['include'] in columns and len(columns) == 1
431+
# Include as a list should have all the columns without model_id
432+
else:
433+
assert columns == exclude(params['include'], 'model_id')
420434
elif params['detailed']:
421-
assert columns == column_types['detailed']
435+
assert columns == exclude(column_types['detailed'], 'model_id')
422436
else:
423-
assert columns == column_types['simple']
437+
assert columns == exclude(column_types['simple'], 'model_id')
424438

425439
# Ensure that if it's ensemble only
426440
# Can only check if 'ensemble_weight' is present

0 commit comments

Comments
 (0)