
Commit 6231b1c

Leaderboard (#1185)
* Implemented `def leaderboard`. Still requires testing, only works for classification
* Fixed some bugs
* Updated function with new params
* Cleaned info gathering a little
* Identifies whether models are classifiers or regressors
* Implemented sort_by param
* Added ranking column
* Implemented ensemble_only param for leaderboard
* Implemented param top_k
* flake8'd
* Created fixtures for use with test_leaderboard
* Moved fixtures to conftest, added session-scoped tmp_dir
  For the autoML models to be usable for the entire session without retraining, they require a session-scoped tmp_dir. I tried to figure out how to make the tmp_dir more dynamic, but the documentation seems to imply that the scope is set at *function definition*, not on function call. This means either calling _tmp_dir and manually cleaning up, or duplicating the tmp_dir function but aptly named for session scope. It's a bit ugly, but I couldn't find an alternative.
* Can't make tmp_dir for session-scoped fixtures
  The request.module object isn't populated when requesting from a session scope. For now, module scope will have to do.
* Reverted back; models trained in test
* Moved `leaderboard` AutoML -> AutoSklearnEstimator
* Added fuzzing test for test_leaderboard
* Added tests for leaderboard, added sort_order
* Removed Type Final to support Python 3.7
* Removed old solution to is_classification for leaderboard
* I should really force pre-commit to run before commit (flake8 fixes)
* More occurrences of Literal
* Re-added Literal but imported from typing_extensions
* Fixed docstring for Sphinx
* Added make command to build html without running examples
* Added doc/examples to gitignore
  Generating the Sphinx examples causes output to be generated in doc/examples. Not sure if this should be pushed, considering docs/build is not.
* Added leaderboard to basic examples
  Found a bug:
  /home/skantify/code/auto-sklearn/examples/20_basic/example_multilabel_classification.py failed to execute correctly:
  Traceback (most recent call last):
    File "/home/skantify/code/auto-sklearn/examples/20_basic/example_multilabel_classification.py", line 61, in <module>
      print(automl.leaderboard())
    File "/home/skantify/code/auto-sklearn/autosklearn/estimators.py", line 738, in leaderboard
      model_runs[model_id]['ensemble_weight'] = self.automl_.ensemble_.weights_[i]
  KeyError: 2
* Cleaned up __str__ of EnsembleSelection
* Fixed discrepancy between config_id and model_id
  There is a discrepancy between the identifiers used by SMAC and the identifiers used by an Ensemble class. SMAC uses `config_id`, which is available for every run of SMAC, while Ensemble uses `model_id == num_run`, which is only available in runinfo.additional_info. However, this is not always included in additional_info, nor is additional_info guaranteed to exist. Therefore the only guaranteed unique identifiers for models are `config_id`s, which can confuse the user if they wish to interact with the ensembler.
* Re-added desired code for design choice on model indexing
  There are two indexes that can be used: SMAC uses `config_id` and auto-sklearn uses `num_run`. These are not guaranteed to be equal, and `num_run` is not always present. As the user should not care that there are possibly two indexes for models, I made the choice to show `config_id`, as this allows displaying info on failed runs. An alternative, to show auto-sklearn's `num_run` index, is simply to exclude any failed runs from the leaderboard.
* Removed Literal again as typing_extensions is an external module
* Switched to model_id as primary id
  Any runs which do not provide a model_id == num_run are essentially discarded. This should change in the future, but the fix is outside the scope of the PR.
* pre-commit flake8 fix
* Logger gives a warning if sort_by is not in the columns asked for
* Moved column types to a static method
* Fixed rank to be based on cost
* Fixed so model_id can be requested, even though it always exists
* Fixed so rank can be calculated even if cost is not requested
* Re-added Literal and included the typing_extensions dependency
  Once Python 3.7 is dropped, we can drop typing_extensions.
* Changed default sort_order to 'auto'
* Changed leaderboard columns to be static attributes
* Update budget doc (Co-authored-by: Matthias Feurer <[email protected]>)
* flake8'd

Co-authored-by: Matthias Feurer <[email protected]>
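For orientation, here is a minimal usage sketch of the `leaderboard()` API this commit adds. The `leaderboard()` parameters are taken from the diff to autosklearn/estimators.py below; the dataset and search settings are illustrative only and not part of the commit.

    # Hypothetical usage sketch of the leaderboard() API added in this commit.
    # Only the leaderboard() calls and their parameters come from the diff;
    # the fit settings are arbitrary placeholders.
    import sklearn.datasets
    import sklearn.model_selection
    import autosklearn.classification

    X, y = sklearn.datasets.load_breast_cancer(return_X_y=True)
    X_train, X_test, y_train, y_test = sklearn.model_selection.train_test_split(X, y)

    automl = autosklearn.classification.AutoSklearnClassifier(time_left_for_this_task=120)
    automl.fit(X_train, y_train)

    # Simple view: ensemble members only, sorted by cost (the defaults)
    print(automl.leaderboard())

    # Detailed view of every evaluated model, highest ensemble weight first
    print(automl.leaderboard(detailed=True, ensemble_only=False,
                             sort_by='ensemble_weight', top_k=10))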
1 parent 611cf5c commit 6231b1c

File tree

12 files changed: +444 -19 lines changed


.gitignore

Lines changed: 5 additions & 0 deletions
@@ -1,8 +1,13 @@
 # Documentation
 docs/build/*
+docs/examples
 
 *.py[cod]
 
+# Exmaples
+# examples 40_advanced generate a tmp_folder
+examples/40_advanced/tmp_folder
+
 # C extensions
 *.c
 *.so

autosklearn/automl.py

Lines changed: 0 additions & 1 deletion
@@ -201,7 +201,6 @@ def __init__(self,
         self.cv_models_ = None
         self.ensemble_ = None
         self._can_predict = False
-
         self._debug_mode = debug_mode
 
         self.InputValidator = None  # type: Optional[InputValidator]

autosklearn/ensembles/ensemble_selection.py

Lines changed: 14 additions & 8 deletions
@@ -278,14 +278,20 @@ def predict(self, predictions: Union[np.ndarray, List[np.ndarray]]) -> np.ndarray
         return average
 
     def __str__(self) -> str:
-        return 'Ensemble Selection:\n\tTrajectory: %s\n\tMembers: %s' \
-               '\n\tWeights: %s\n\tIdentifiers: %s' % \
-               (' '.join(['%d: %5f' % (idx, performance)
-                          for idx, performance in enumerate(self.trajectory_)]),
-                self.indices_, self.weights_,
-                ' '.join([str(identifier) for idx, identifier in
-                          enumerate(self.identifiers_)
-                          if self.weights_[idx] > 0]))
+        trajectory_str = ' '.join([
+            f'{id}: {perf:.5f}'
+            for id, perf in enumerate(self.trajectory_)
+        ])
+        identifiers_str = ' '.join([
+            f'{identifier}'
+            for idx, identifier in enumerate(self.identifiers_)
+            if self.weights_[idx] > 0
+        ])
+        return ("Ensemble Selection:\n"
+                f"\tTrajectory: {trajectory_str}\n"
+                f"\tMembers: {self.indices_}\n"
+                f"\tWeights: {self.weights_}\n"
+                f"\tIdentifiers: {identifiers_str}\n")
 
     def get_models_with_weights(
         self,
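As a quick illustration of where this `__str__` surfaces (assuming an already fitted estimator named `automl`; the `automl_.ensemble_` attribute path appears in the estimators.py diff below):

    # `automl` is assumed to be an already fitted AutoSklearnClassifier/Regressor.
    # Printing the underlying EnsembleSelection goes through the __str__ above.
    print(automl.automl_.ensemble_)
    # Rough shape of the output (values are illustrative, not real results):
    # Ensemble Selection:
    #     Trajectory: 0: 0.12000 1: 0.10000 ...
    #     Members: [3, 7, 7, ...]
    #     Weights: [0.  0.3 ...]
    #     Identifiers: (1, 3, 0.0) (1, 7, 0.0) ...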

autosklearn/estimators.py

Lines changed: 270 additions & 2 deletions
@@ -1,11 +1,12 @@
 # -*- encoding: utf-8 -*-
-
-from typing import Optional, Dict, List, Tuple, Union
+from typing import Optional, Dict, List, Tuple, Union, Iterable, ClassVar
+from typing_extensions import Literal
 
 from ConfigSpace.configuration_space import Configuration
 import dask.distributed
 import joblib
 import numpy as np
+import pandas as pd
 from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
 from sklearn.utils.multiclass import type_of_target
 from smac.runhistory.runhistory import RunInfo, RunValue
@@ -21,6 +22,18 @@
 
 
 class AutoSklearnEstimator(BaseEstimator):
+    # Constants used by `def leaderboard` for columns and their sort order
+    _leaderboard_columns: ClassVar[Dict[str, List[str]]] = {
+        "all": [
+            "model_id", "rank", "ensemble_weight", "type", "cost", "duration",
+            "config_id", "train_loss", "seed", "start_time", "end_time",
+            "budget", "status", "data_preprocessors", "feature_preprocessors",
+            "balancing_strategy", "config_origin"
+        ],
+        "simple": [
+            "model_id", "rank", "ensemble_weight", "type", "cost", "duration"
+        ]
+    }
 
     def __init__(
         self,
@@ -550,6 +563,261 @@ def sprint_statistics(self):
         """
         return self.automl_.sprint_statistics()
 
+    def leaderboard(
+        self,
+        detailed: bool = False,
+        ensemble_only: bool = True,
+        top_k: Union[int, Literal['all']] = 'all',
+        sort_by: str = 'cost',
+        sort_order: Literal['auto', 'ascending', 'descending'] = 'auto',
+        include: Optional[Union[str, Iterable[str]]] = None
+    ) -> pd.DataFrame:
+        """ Returns a pandas table of results for all evaluated models.
+
+        Gives an overview of all models trained during the search process along
+        with various statistics about their training.
+
+        The availble statistics are:
+
+        **Simple**:
+
+        * ``"model_id"`` - The id given to a model by ``autosklearn``.
+        * ``"rank"`` - The rank of the model based on it's ``"cost"``.
+        * ``"ensemble_weight"`` - The weight given to the model in the ensemble.
+        * ``"type"`` - The type of classifier/regressor used.
+        * ``"cost"`` - The loss of the model on the validation set.
+        * ``"duration"`` - Length of time the model was optimized for.
+
+        **Detailed**:
+        The detailed view includes all of the simple statistics along with the
+        following.
+
+        * ``"config_id"`` - The id used by SMAC for optimization.
+        * ``"budget"`` - How much budget was allocated to this model.
+        * ``"status"`` - The return status of training the model with SMAC.
+        * ``"train_loss"`` - The loss of the model on the training set.
+        * ``"balancing_strategy"`` - The balancing strategy used for data preprocessing.
+        * ``"start_time"`` - Time the model began being optimized
+        * ``"end_time"`` - Time the model ended being optimized
+        * ``"data_preprocessors"`` - The preprocessors used on the data
+        * ``"feature_preprocessors"`` - The preprocessors for features types
+
+        Parameters
+        ----------
+        detailed: bool = False
+            Whether to give detailed information or just a simple overview.
+
+        ensemble_only: bool = True
+            Whether to view only models included in the ensemble or all models
+            trained.
+
+        top_k: int or "all" = "all"
+            How many models to display.
+
+        sort_by: str = 'cost'
+            What column to sort by. If that column is not present, the
+            sorting defaults to the ``"model_id"`` index column.
+
+        sort_order: "auto" or "ascending" or "descending" = "auto"
+            Which sort order to apply to the ``sort_by`` column. If left
+            as ``"auto"``, it will sort by a sensible default where "better" is
+            on top, otherwise defaulting to the pandas default for
+            `DataFrame.sort_values`_ if there is no obvious "better".
+
+            .. _DataFrame.sort_values: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.sort_values.html
+
+        include: Optional[str or Iterable[str]]
+            Items to include, other items not specified will be excluded.
+            The exception is the ``"model_id"`` index column which is always included.
+
+            If left as ``None``, it will resort back to using the ``detailed``
+            param to decide the columns to include.
+
+        Returns
+        -------
+        pd.DataFrame
+            A dataframe of statistics for the models, ordered by ``sort_by``.
+
+        """  # noqa (links are too long)
+        # TODO validate that `self` is fitted. This is required for
+        # self.ensemble_ to get the identifiers of models it will generate
+        # weights for.
+        column_types = {
+            'all': AutoSklearnEstimator._leaderboard_columns['all'],
+            'simple': AutoSklearnEstimator._leaderboard_columns['simple'],
+            'detailed': AutoSklearnEstimator._leaderboard_columns['all']
+        }
+
+        # Validation of top_k
+        if (
+            not (isinstance(top_k, str) or isinstance(top_k, int))
+            or (isinstance(top_k, str) and top_k != 'all')
+            or (isinstance(top_k, int) and top_k <= 0)
+        ):
+            raise ValueError(f"top_k={top_k} must be a positive integer or pass"
+                             " `top_k`='all' to view results for all models")
+
+        # Validate columns to include
+        if isinstance(include, str):
+            include = [include]
+
+        if include is not None:
+            columns = [*include]
+
+            # 'model_id' should always be present as it is the unique index
+            # used for pandas
+            if 'model_id' not in columns:
+                columns.append('model_id')
+
+            invalid_include_items = set(columns) - set(column_types['all'])
+            if len(invalid_include_items) != 0:
+                raise ValueError(f"Values {invalid_include_items} are not known"
+                                 f" columns to include, must be contained in "
+                                 f"{column_types['all']}")
+        elif detailed:
+            columns = column_types['all']
+        else:
+            columns = column_types['simple']
+
+        # Validation of sorting
+        if sort_by not in column_types['all']:
+            raise ValueError(f"sort_by='{sort_by}' must be one of included "
+                             f"columns {set(column_types['all'])}")
+
+        valid_sort_orders = ['auto', 'ascending', 'descending']
+        if not (isinstance(sort_order, str) and sort_order in valid_sort_orders):
+            raise ValueError(f"`sort_order` = {sort_order} must be a str in "
+                             f"{valid_sort_orders}")
+
+        # To get all the models that were optmized, we collect what we can from
+        # runhistory first.
+        def has_key(rv, key):
+            return rv.additional_info and key in rv.additional_info
+
+        model_runs = {
+            rval.additional_info['num_run']: {
+                'model_id': rval.additional_info['num_run'],
+                'seed': rkey.seed,
+                'budget': rkey.budget,
+                'duration': rval.time,
+                'config_id': rkey.config_id,
+                'start_time': rval.starttime,
+                'end_time': rval.endtime,
+                'status': str(rval.status),
+                'cost': rval.cost,
+                'train_loss': rval.additional_info['train_loss']
+                if has_key(rval, 'train_loss') else None,
+                'config_origin': rval.additional_info['configuration_origin']
+                if has_key(rval, 'configuration_origin') else None
+            }
+            for rkey, rval in self.automl_.runhistory_.data.items()
+            if has_key(rval, 'num_run')
+        }
+
+        # Next we get some info about the model itself
+        model_class_strings = {
+            AutoMLClassifier: 'classifier',
+            AutoMLRegressor: 'regressor'
+        }
+        model_type = model_class_strings.get(self._get_automl_class(), None)
+        if model_type is None:
+            raise RuntimeError(f"Unknown `automl_class` {self._get_automl_class()}")
+
+        # A dict mapping model ids to their configurations
+        configurations = self.automl_.runhistory_.ids_config
+
+        for model_id, run_info in model_runs.items():
+            config_id = run_info['config_id']
+            run_config = configurations[config_id]._values
+
+            run_info.update({
+                'balancing_strategy': run_config.get('balancing:strategy', None),
+                'type': run_config[f'{model_type}:__choice__'],
+                'data_preprocessors': [
+                    value for key, value in run_config.items()
+                    if 'data_preprocessing' in key and '__choice__' in key
+                ],
+                'feature_preprocessors': [
+                    value for key, value in run_config.items()
+                    if 'feature_preprocessor' in key and '__choice__' in key
+                ]
+            })
+
+        # Get the models ensemble weight if it has one
+        # TODO both implementing classes of AbstractEnsemble have a property
+        # `identifiers_` and `weights_`, might be good to put it as an
+        # abstract property
+        # TODO `ensemble_.identifiers_` and `ensemble_.weights_` are loosely
+        # tied together by ordering, might be better to store as tuple
+        for i, weight in enumerate(self.automl_.ensemble_.weights_):
+            (_, model_id, _) = self.automl_.ensemble_.identifiers_[i]
+            model_runs[model_id]['ensemble_weight'] = weight
+
+        # Filter out non-ensemble members if needed, else fill in a default
+        # value of 0 if it's missing
+        if ensemble_only:
+            model_runs = {
+                model_id: info
+                for model_id, info in model_runs.items()
+                if ('ensemble_weight' in info and info['ensemble_weight'] > 0)
+            }
+        else:
+            for model_id, info in model_runs.items():
+                if 'ensemble_weight' not in info:
+                    info['ensemble_weight'] = 0
+
+        # `rank` relies on `cost` so we include `cost`
+        # We drop it later if it's not requested
+        if 'rank' in columns and 'cost' not in columns:
+            columns = [*columns, 'cost']
+
+        # Finally, convert into a tabular format by converting the dict into
+        # column wise orientation.
+        dataframe = pd.DataFrame({
+            col: [run_info[col] for run_info in model_runs.values()]
+            for col in columns if col != 'rank'
+        })
+
+        # Give it an index, even if not in the `include`
+        dataframe.set_index('model_id', inplace=True)
+
+        # Add the `rank` column if needed, dropping `cost` if it's not
+        # requested by the user
+        if 'rank' in columns:
+            dataframe.sort_values(by='cost', ascending=False, inplace=True)
+            dataframe.insert(column='rank',
+                             value=range(1, len(dataframe) + 1),
+                             loc=list(columns).index('rank'))
+
+            if 'cost' not in columns:
+                dataframe.drop('cost', inplace=True)
+
+        # Decide on the sort order depending on what it gets sorted by
+        descending_columns = ['ensemble_weight', 'duration']
+        if sort_order == 'auto':
+            ascending_param = False if sort_by in descending_columns else True
+        else:
+            ascending_param = False if sort_order == 'descending' else True
+
+        # Sort by the given column name, defaulting to 'model_id' if not present
+        if sort_by not in dataframe.columns:
+            self.automl_._logger.warning(f"sort_by = '{sort_by}' was not present"
+                                         ", defaulting to sort on the index "
+                                         "'model_id'")
+            sort_by = 'model_id'
+
+        dataframe.sort_values(by=sort_by,
+                              ascending=ascending_param,
+                              inplace=True)
+
+        # Lastly, just grab the top_k
+        if top_k == 'all' or top_k >= len(dataframe):
+            top_k = len(dataframe)
+
+        dataframe = dataframe.head(top_k)
+
+        return dataframe
+
     def _get_automl_class(self):
         raise NotImplementedError()
 
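A side note on the `rank` handling above, since several commits in this PR touched it: below is a minimal, self-contained pandas sketch of deriving a rank column from a cost column. The model ids and costs are invented, and the sketch ranks the lowest cost first; the diff above sorts with `ascending=False` before inserting the rank, so the direction here is illustrative rather than a copy of the implementation.

    # Standalone illustration of deriving a rank column from a cost column.
    # The data below is invented purely for demonstration.
    import pandas as pd

    runs = pd.DataFrame(
        {'model_id': [2, 3, 5], 'cost': [0.08, 0.12, 0.05]}
    ).set_index('model_id')

    # Rank 1 goes to the smallest cost here (lower cost == better model).
    runs = runs.sort_values(by='cost', ascending=True)
    runs.insert(loc=0, column='rank', value=range(1, len(runs) + 1))

    print(runs)
    # Approximate output:
    #           rank  cost
    # model_id
    # 5            1  0.05
    # 2            2  0.08
    # 3            3  0.12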

doc/Makefile

Lines changed: 7 additions & 1 deletion
@@ -19,7 +19,7 @@ ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
 # the i18n builder cannot share the environment and doctrees with the others
 I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) .
 
-.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
+.PHONY: help clean html html-noexamples dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext
 
 all: html
 
@@ -59,6 +59,12 @@ html:
 	@echo
 	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
 
+html-noexamples:
+	$(SPHINXBUILD) -D plot_gallery=0 -b html $(ALLSPHINXOPTS) $(SOURCEDIR) $(BUILDDIR)/html
+	@echo
+	@echo "Build finished. The HTML pages are in $(BUILDDIR)/html."
+
+
 dirhtml:
 	$(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml
 	@echo

examples/20_basic/example_classification.py

Lines changed: 7 additions & 0 deletions
@@ -32,6 +32,12 @@
 )
 automl.fit(X_train, y_train, dataset_name='breast_cancer')
 
+############################################################################
+# View the models found by auto-sklearn
+# =====================================
+
+print(automl.leaderboard())
+
 ############################################################################
 # Print the final ensemble constructed by auto-sklearn
 # ====================================================
@@ -44,3 +50,4 @@
 
 predictions = automl.predict(X_test)
 print("Accuracy score:", sklearn.metrics.accuracy_score(y_test, predictions))
+
