
Commit 8bac6b2

Authored by Sean Friedowitz
Merge pull request #778 from CitrineInformatics/model-selection-report
Support for model selection summary
2 parents f766171 + 8e1efc5 commit 8bac6b2

File tree

4 files changed: +115, -10 lines changed


src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.36.6'
+__version__ = '1.37.0'

src/citrine/informatics/reports.py

Lines changed: 70 additions & 8 deletions
@@ -1,5 +1,5 @@
 """Tools for working with reports."""
-from typing import Type, Dict, TypeVar, Iterable
+from typing import Type, Dict, TypeVar, Iterable, Any, Set
 from abc import abstractmethod
 from itertools import groupby
 from logging import getLogger
@@ -9,6 +9,7 @@
 from citrine._serialization.serializable import Serializable
 from citrine._rest.asynchronous_object import AsynchronousObject
 from citrine.informatics.descriptors import Descriptor
+from citrine.informatics.predictor_evaluation_result import ResponseMetrics

 SelfType = TypeVar('SelfType', bound='Report')

@@ -62,6 +63,58 @@ def __str__(self):
         return "<FeatureImportanceReport {!r}>".format(self.output_key)  # pragma: no cover


+class ModelEvaluationResult(Serializable["ModelEvaluationResult"]):
+    """[ALPHA] Settings and evaluation metrics for a single algorithm from AutoML model selection.
+
+    ModelEvaluationResult objects are included in a ModelSelectionReport
+    and should not be user-instantiated.
+    """
+
+    model_settings = properties.Raw('model_settings')
+    _response_results = properties.Mapping(
+        properties.String,
+        properties.Object(ResponseMetrics),
+        "response_results"
+    )
+
+    def __init__(self):
+        pass  # pragma: no cover
+
+    def __str__(self):
+        return '<ModelEvaluationResult>'  # pragma: no cover
+
+    def __getitem__(self, item):
+        return self._response_results[item]
+
+    def __iter__(self):
+        return iter(self.responses)
+
+    @property
+    def responses(self) -> Set[str]:
+        """Responses the model was evaluated on."""
+        return set(self._response_results.keys())
+
+
+class ModelSelectionReport(Serializable["ModelSelectionReport"]):
+    """[ALPHA] Summary of selection settings and model performance from AutoML model selection.
+
+    ModelSelectionReport objects are constructed from saved models and
+    should not be user-instantiated.
+    """
+
+    n_folds = properties.Integer('n_folds')
+    evaluation_results = properties.List(
+        properties.Object(ModelEvaluationResult),
+        "evaluation_results"
+    )
+
+    def __init__(self):
+        pass  # pragma: no cover
+
+    def __str__(self):
+        return '<ModelSelectionReport>'  # pragma: no cover
+
+
 class ModelSummary(Serializable['ModelSummary']):
     """[ALPHA] Summary of information about a single model in a predictor.

@@ -87,6 +140,10 @@ class ModelSummary(Serializable['ModelSummary']):
     feature_importances = properties.List(
         properties.Object(FeatureImportanceReport), 'feature_importances')
     """:List[FeatureImportanceReport]: feature importance reports for each output"""
+    selection_summary = properties.Optional(
+        properties.Object(ModelSelectionReport), "selection_summary"
+    )
+    """:Optional[ModelSelectionReport]: optional results of AutoML model selection"""
     predictor_name = properties.String('predictor_configuration_name', default='')
     """:str: the name of the predictor that created this model"""
     predictor_uid = properties.Optional(properties.UUID(), 'predictor_configuration_uid')
@@ -126,8 +183,13 @@ def __init__(self):
     def post_build(self):
         """Modify a PredictorReport object in-place after deserialization."""
         self._fill_out_descriptors()
-        for _, model in enumerate(self.model_summaries):
-            self._collapse_model_settings(model)
+        for _, summary in enumerate(self.model_summaries):
+            # Collapse settings on final trained model
+            summary.model_settings = self._collapse_model_settings(summary.model_settings)
+            if summary.selection_summary is not None:
+                # Collapse settings on any child model evaluation results
+                for result in summary.selection_summary.evaluation_results:
+                    result.model_settings = self._collapse_model_settings(result.model_settings)

     def _fill_out_descriptors(self):
         """Replace descriptor keys in `model_summaries` with full Descriptor objects."""
@@ -146,7 +208,7 @@ def _fill_out_descriptors(self):
                     model.outputs[j] = descriptor_map[output_key]
                 except KeyError:
                     raise RuntimeError("Model {} contains output \'{}\', but no descriptor found "
-                                       "with that key".format(model.name, input_key))
+                                       "with that key".format(model.name, output_key))

     @staticmethod
     def _get_sole_descriptor(it: Iterable):
@@ -170,7 +232,7 @@ def _get_sole_descriptor(it: Iterable):
         return as_list[0]

     @staticmethod
-    def _collapse_model_settings(model: ModelSummary):
+    def _collapse_model_settings(raw_settings: Dict[str, Any]):
        """Collapse a model's settings into a flat dictionary.

        Model settings are returned as a dictionary with a "name" field, a "value" field,
@@ -187,6 +249,6 @@ def _recurse_model_settings(settings: Dict[str, str], list_or_dict):
                settings[list_or_dict['name']] = list_or_dict['value']
                _recurse_model_settings(settings, list_or_dict['children'])

-        settings = dict()
-        _recurse_model_settings(settings, model.model_settings)
-        model.model_settings = settings
+        collapsed = dict()
+        _recurse_model_settings(collapsed, raw_settings)
+        return collapsed
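
For orientation, a rough usage sketch of the new classes once a report containing a selection summary has been built. This is not part of the diff: `report_data` is a placeholder payload (for example, the `valid_predictor_report_data` fixture below), and the attribute and method names follow the classes and tests in this commit.

    from citrine.informatics.reports import PredictorReport

    # report_data is a hypothetical dict shaped like the fixture in tests/conftest.py
    report = PredictorReport.build(report_data)
    for summary in report.model_summaries:
        selection = summary.selection_summary
        if selection is None:
            continue  # this model did not go through AutoML model selection
        print(selection.n_folds)              # number of cross-validation folds
        for result in selection.evaluation_results:
            print(result.model_settings)      # flat dict after post_build() collapses it
            for response in result:           # iterates result.responses
                response_metrics = result[response]  # a ResponseMetrics object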

tests/conftest.py

Lines changed: 27 additions & 1 deletion
@@ -317,7 +317,7 @@ def valid_expression_predictor_data():


 @pytest.fixture
-def valid_predictor_report_data():
+def valid_predictor_report_data(example_categorical_pva_metrics, example_f1_metrics):
     """Produce valid data used for tests."""
     from citrine.informatics.descriptors import RealDescriptor
     x = RealDescriptor("x", lower_bound=0, upper_bound=1, units="")
@@ -355,6 +355,32 @@ def valid_predictor_report_data():
                    top_features=5
                )
            ],
+            selection_summary=dict(
+                n_folds=4,
+                evaluation_results=[
+                    dict(
+                        model_settings=[
+                            dict(
+                                name='Algorithm',
+                                value='Ensemble of non-linear estimators',
+                                children=[
+                                    dict(name='Number of estimators', value=64, children=[]),
+                                    dict(name='Leaf model', value='Mean', children=[]),
+                                    dict(name='Use jackknife', value=True, children=[])
+                                ]
+                            )
+                        ],
+                        response_results=dict(
+                            response_name=dict(
+                                metrics=dict(
+                                    predicted_vs_actual=example_categorical_pva_metrics,
+                                    f1=example_f1_metrics
+                                )
+                            )
+                        )
+                    )
+                ]
+            ),
            predictor_configuration_name="Predict y from x with ML"
        ),
        dict(
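
As a side note (not part of the diff): given this fixture's nested model_settings payload, PredictorReport.post_build() flattens it via _collapse_model_settings into a name-to-value dictionary, roughly:

    # Approximate collapsed form of the fixture's model_settings above
    {
        'Algorithm': 'Ensemble of non-linear estimators',
        'Number of estimators': 64,
        'Leaf model': 'Mean',
        'Use jackknife': True,
    }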

tests/informatics/test_reports.py

Lines changed: 17 additions & 0 deletions
@@ -7,3 +7,20 @@ def test_status(valid_predictor_report_data):
     """Ensure we can check the status of report generation."""
     report = Report.build(valid_predictor_report_data)
     assert report.succeeded() and not report.in_progress() and not report.failed()
+
+
+def test_selection_summary(valid_predictor_report_data):
+    """Ensure that we can iterate selection summary results as expected."""
+    report = PredictorReport.build(valid_predictor_report_data)
+    selection_summaries = [
+        s.selection_summary for s in report.model_summaries if s.selection_summary is not None
+    ]
+
+    assert len(selection_summaries) > 0
+    for s in selection_summaries:
+        assert len(s.evaluation_results) > 0
+        for result in s.evaluation_results:
+            assert len(result.model_settings) > 0
+            for response_key in result:
+                metrics = result[response_key].metrics
+                assert len(metrics) > 0
