
Commit 984df3d

Author: bfolie
Merge pull request #768 from CitrineInformatics/PLA-9819/holdout-evaluator
Add holdout evaluator
2 parents ac2c044 + 69d56d4 commit 984df3d

File tree: 7 files changed, 186 additions & 44 deletions

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = '1.39.0'
+__version__ = '1.40.0'

src/citrine/informatics/predictor_evaluation_result.py

Lines changed: 44 additions & 5 deletions
@@ -4,8 +4,8 @@
 from citrine._serialization.polymorphic_serializable import PolymorphicSerializable
 from citrine._serialization.serializable import Serializable
 from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric
-from citrine.informatics.predictor_evaluator import PredictorEvaluator
-
+from citrine.informatics.predictor_evaluator import PredictorEvaluator, HoldoutSetEvaluator,\
+    CrossValidationEvaluator

 __all__ = ['MetricValue',
            'RealMetricValue',
@@ -166,6 +166,7 @@ def get_type(cls, data) -> Type[Serializable]:
         """Return the subtype."""
         return {
             "CrossValidationResult": CrossValidationResult,
+            "HoldoutSetResult": HoldoutSetResult
         }[data["type"]]

     @property
@@ -191,11 +192,11 @@ class CrossValidationResult(Serializable["CrossValidationResult"], PredictorEval
     where ``cvResult`` is a
     :class:`citrine.informatics.predictor_evaluation_result.CrossValidationResult`
     and ``'response_name'`` is a response analyzed by a
-    :class:`citrine.informatics.predictor_evaluator.PredictorEvaluator`.
+    :class:`citrine.informatics.predictor_evaluator.CrossValidationEvaluator`.

     """

-    _evaluator = properties.Object(PredictorEvaluator, "evaluator")
+    _evaluator = properties.Object(CrossValidationEvaluator, "evaluator")
     _response_results = properties.Mapping(properties.String, properties.Object(ResponseMetrics),
                                            "response_results")
     typ = properties.String('type', default='CrossValidationResult', deserializable=False)
@@ -207,7 +208,45 @@ def __iter__(self):
         return iter(self.responses)

     @property
-    def evaluator(self) -> PredictorEvaluator:
+    def evaluator(self) -> CrossValidationEvaluator:
+        """:PredictorEvaluator: Evaluator that produced this result."""
+        return self._evaluator
+
+    @property
+    def responses(self) -> Set[str]:
+        """Responses for which results are present."""
+        return set(self._response_results.keys())
+
+    @property
+    def metrics(self) -> Set[PredictorEvaluationMetric]:
+        """:Set[PredictorEvaluationMetric]: Metrics for which results are present."""
+        return self._evaluator.metrics
+
+
+class HoldoutSetResult(Serializable["HoldoutSetResult"], PredictorEvaluationResult):
+    """Result of performing holdout evaluation on a predictor.
+
+    Results for each held-out response can be accessed via ``result['response_name']``,
+    where ``result`` is a
+    :class:`citrine.informatics.predictor_evaluation_result.HoldoutSetResult`
+    and ``'response_name'`` is a response analyzed by a
+    :class:`citrine.informatics.predictor_evaluator.HoldoutSetEvaluator`.
+
+    """
+
+    _evaluator = properties.Object(HoldoutSetEvaluator, "evaluator")
+    _response_results = properties.Mapping(properties.String, properties.Object(ResponseMetrics),
+                                           "response_results")
+    typ = properties.String('type', default='HoldoutSetResult', deserializable=False)
+
+    def __getitem__(self, item):
+        return self._response_results[item]
+
+    def __iter__(self):
+        return iter(self.responses)
+
+    @property
+    def evaluator(self) -> HoldoutSetEvaluator:
         """:PredictorEvaluator: Evaluator that produced this result."""
         return self._evaluator

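As context for the new class, a rough sketch of how a HoldoutSetResult might be inspected after an evaluation has run. The `execution` object and the evaluator name string are assumptions made for illustration; as in the execution test further below, indexing an execution by evaluator name returns that evaluator's result:

from citrine.informatics.predictor_evaluation_metrics import RMSE

# Assumption: `execution` is a completed PredictorEvaluationExecution retrieved elsewhere.
result = execution["Example holdout evaluator"]

print(result.responses)    # e.g. {"sweetness"}
print(result.metrics)      # metrics configured on the evaluator, e.g. {RMSE()}
for response in result:    # iterating a result yields its response names
    rmse_value = result[response][RMSE()]   # metric values are indexed by response, then metric
    print(response, rmse_value.mean, rmse_value.standard_error)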

src/citrine/informatics/predictor_evaluator.py

Lines changed: 59 additions & 1 deletion
@@ -4,9 +4,12 @@
 from citrine._serialization.polymorphic_serializable import PolymorphicSerializable
 from citrine._serialization.serializable import Serializable
 from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric
+from citrine.informatics.data_sources import DataSource

 __all__ = ['PredictorEvaluator',
-           'CrossValidationEvaluator']
+           'CrossValidationEvaluator',
+           'HoldoutSetEvaluator'
+           ]


 class PredictorEvaluator(PolymorphicSerializable["PredictorEvaluator"]):
@@ -17,6 +20,7 @@ def get_type(cls, data) -> Type[Serializable]:
         """Return the subtype."""
         return {
             "CrossValidationEvaluator": CrossValidationEvaluator,
+            "HoldoutSetEvaluator": HoldoutSetEvaluator
         }[data["type"]]

     def _attrs(self) -> List[str]:
@@ -129,3 +133,57 @@ def responses(self) -> Set[str]:
     def metrics(self) -> Set[PredictorEvaluationMetric]:
         """Set of metrics computed during cross-validation."""
         return self._metrics
+
+
+class HoldoutSetEvaluator(Serializable["HoldoutSetEvaluator"], PredictorEvaluator):
+    """Evaluate a predictor using a holdout set.
+
+    For each response, the actual values are masked off and the predictor makes predictions.
+    These predictions are compared with the ground-truth values in the holdout set using
+    specified metrics.
+
+    Parameters
+    ----------
+    name: str
+        Name of the evaluator
+    responses: Set[str]
+        Set of descriptor keys to evaluate
+    data_source: DataSource
+        Source of holdout data
+    metrics: Optional[Set[PredictorEvaluationMetric]]
+        Optional set of metrics to compute for each response. Default is all metrics.
+
+    """
+
+    def _attrs(self) -> List[str]:
+        return ["typ", "name", "responses", "data_source", "metrics"]
+
+    name = properties.String("name")
+    description = properties.String("description")
+    _responses = properties.Set(properties.String, "responses")
+    data_source = properties.Object(DataSource, "data_source")
+    _metrics = properties.Optional(properties.Set(properties.Object(PredictorEvaluationMetric)),
+                                   "metrics")
+    typ = properties.String("type", default="HoldoutSetEvaluator", deserializable=False)
+
+    def __init__(self,
+                 name: str, *,
+                 description: str = "",
+                 responses: Set[str],
+                 data_source: DataSource,
+                 metrics: Optional[Set[PredictorEvaluationMetric]] = None):
+        self.name: str = name
+        self.description: str = description
+        self._responses: Set[str] = responses
+        self.data_source = data_source
+        self._metrics: Optional[Set[PredictorEvaluationMetric]] = metrics
+
+    @property
+    def responses(self) -> Set[str]:
+        """Set of responses to predict and compare against the ground-truth values."""
+        return self._responses
+
+    @property
+    def metrics(self) -> Set[PredictorEvaluationMetric]:
+        """Set of metrics computed on the predictions."""
+        return self._metrics
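For orientation, a minimal usage sketch of the new evaluator. It is illustrative only: the table id, version, names, and descriptions are placeholders, and the choice of GemTableDataSource and RMSE simply mirrors the tests below.

import uuid

from citrine.informatics.data_sources import GemTableDataSource
from citrine.informatics.predictor_evaluation_metrics import RMSE
from citrine.informatics.predictor_evaluator import HoldoutSetEvaluator
from citrine.informatics.workflows import PredictorEvaluationWorkflow

# Placeholder table id/version; in practice these identify an existing GEM table
# holding the ground-truth holdout data.
holdout_data = GemTableDataSource(table_id=uuid.uuid4(), table_version=1)

evaluator = HoldoutSetEvaluator(
    name="sweetness holdout",
    description="Compare predictions against a held-out table",
    responses={"sweetness"},
    data_source=holdout_data,
    metrics={RMSE()},  # omit to compute all metrics (the default)
)

# A holdout evaluator plugs into a workflow alongside cross-validation evaluators.
workflow = PredictorEvaluationWorkflow(
    name="Example PEW",
    description="",
    evaluators=[evaluator],
)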

tests/conftest.py

Lines changed: 32 additions & 6 deletions
@@ -571,8 +571,8 @@ def valid_simple_mixture_predictor_data():
     return PredictorEntityDataFactory(data=PredictorDataDataFactory(instance=instance))


-@pytest.fixture()
-def example_evaluator_dict():
+@pytest.fixture
+def example_cv_evaluator_dict():
     return {
         "type": "CrossValidationEvaluator",
         "name": "Example evaluator",
@@ -587,6 +587,17 @@ def example_evaluator_dict():
     }


+@pytest.fixture
+def example_holdout_evaluator_dict(valid_gem_data_source_dict):
+    return {
+        "type": "HoldoutSetEvaluator",
+        "name": "Example holdout evaluator",
+        "description": "",
+        "responses": ["sweetness"],
+        "data_source": valid_gem_data_source_dict,
+        "metrics": [{"type": "RMSE"}]
+    }
+
 @pytest.fixture()
 def example_rmse_metrics():
     return {
@@ -652,10 +663,10 @@ def example_categorical_pva_metrics():


 @pytest.fixture()
-def example_result_dict(example_evaluator_dict, example_rmse_metrics, example_categorical_pva_metrics, example_f1_metrics, example_real_pva_metrics):
+def example_cv_result_dict(example_cv_evaluator_dict, example_rmse_metrics, example_categorical_pva_metrics, example_f1_metrics, example_real_pva_metrics):
     return {
         "type": "CrossValidationResult",
-        "evaluator": example_evaluator_dict,
+        "evaluator": example_cv_evaluator_dict,
         "response_results": {
             "salt?": {
                 "metrics": {
@@ -673,6 +684,21 @@ def example_result_dict(example_evaluator_dict, example_rmse_metrics, example_ca
     }


+@pytest.fixture()
+def example_holdout_result_dict(example_holdout_evaluator_dict, example_rmse_metrics):
+    return {
+        "type": "HoldoutSetResult",
+        "evaluator": example_holdout_evaluator_dict,
+        "response_results": {
+            "sweetness": {
+                "metrics": {
+                    "rmse": example_rmse_metrics
+                }
+            }
+        }
+    }
+
+
 @pytest.fixture()
 def example_candidates():
     return {
@@ -745,12 +771,12 @@ def design_execution_dict(generic_entity):


 @pytest.fixture
-def predictor_evaluation_workflow_dict(generic_entity, example_evaluator_dict):
+def predictor_evaluation_workflow_dict(generic_entity, example_cv_evaluator_dict, example_holdout_evaluator_dict):
     ret = deepcopy(generic_entity)
     ret.update({
         "name": "Example PEW",
         "description": "Example PEW for testing",
-        "evaluators": [example_evaluator_dict]
+        "evaluators": [example_cv_evaluator_dict, example_holdout_evaluator_dict]
     })
     return ret


tests/informatics/test_predictor_evaluation_result.py

Lines changed: 38 additions & 25 deletions
@@ -9,48 +9,61 @@


 @pytest.fixture
-def example_result(example_result_dict):
-    return PredictorEvaluationResult.build(example_result_dict)
+def example_cv_result(example_cv_result_dict):
+    return PredictorEvaluationResult.build(example_cv_result_dict)


-def test_indexing(example_result):
-    assert example_result.responses == {"saltiness", "salt?"}
-    assert example_result.metrics == {RMSE(), PVA(), F1()}
-    assert set(example_result["salt?"]) == {repr(F1()), repr(PVA())}
-    assert set(example_result) == {"salt?", "saltiness"}
+@pytest.fixture
+def example_holdout_result(example_holdout_result_dict):
+    return PredictorEvaluationResult.build(example_holdout_result_dict)
+
+
+def test_indexing(example_cv_result, example_holdout_result):
+    assert example_cv_result.responses == {"saltiness", "salt?"}
+    assert example_holdout_result.responses == {"sweetness"}
+    assert example_cv_result.metrics == {RMSE(), PVA(), F1()}
+    assert example_holdout_result.metrics == {RMSE()}
+    assert set(example_cv_result["salt?"]) == {repr(F1()), repr(PVA())}
+    assert set(example_cv_result) == {"salt?", "saltiness"}
+    assert set(example_holdout_result["sweetness"]) == {repr(RMSE())}
+    assert set(example_holdout_result) == {"sweetness"}
+

+def test_cv_serde(example_cv_result, example_cv_result_dict):
+    round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_cv_result_dict)))
+    assert example_cv_result.evaluator == round_trip.evaluator

-def test_serde(example_result, example_result_dict):
-    round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_result_dict)))
-    assert example_result.evaluator == round_trip.evaluator

+def test_holdout_serde(example_holdout_result, example_holdout_result_dict):
+    round_trip = PredictorEvaluationResult.build(json.loads(json.dumps(example_holdout_result_dict)))
+    assert example_holdout_result.evaluator == round_trip.evaluator

-def test_evaluator(example_result, example_evaluator_dict):
-    args = example_evaluator_dict
+def test_evaluator(example_cv_result, example_cv_evaluator_dict):
+    args = example_cv_evaluator_dict
     del args["type"]
     expected = CrossValidationEvaluator(**args)
-    assert example_result.evaluator == expected
-    assert example_result.evaluator != 0  # make sure eq does something for mismatched classes
+    assert example_cv_result.evaluator == expected
+    assert example_cv_result.evaluator != 0  # make sure eq does something for mismatched classes


-def test_check_rmse(example_result, example_rmse_metrics):
-    assert example_result["saltiness"]["rmse"].mean == example_rmse_metrics["mean"]
-    assert example_result["saltiness"][RMSE()].standard_error == example_rmse_metrics["standard_error"]
+def test_check_rmse(example_cv_result, example_rmse_metrics):
+    assert example_cv_result["saltiness"]["rmse"].mean == example_rmse_metrics["mean"]
+    assert example_cv_result["saltiness"][RMSE()].standard_error == example_rmse_metrics["standard_error"]
     # check eq method does something
-    assert example_result["saltiness"][RMSE()] != 0
+    assert example_cv_result["saltiness"][RMSE()] != 0
     with pytest.raises(TypeError):
-        foo = example_result["saltiness"][0]
+        foo = example_cv_result["saltiness"][0]


-def test_real_pva(example_result, example_real_pva_metrics):
+def test_real_pva(example_cv_result, example_real_pva_metrics):
     args = example_real_pva_metrics["value"][0]
     expected = PredictedVsActualRealPoint.build(args)
-    assert example_result["saltiness"]["predicted_vs_actual"][0].predicted == expected.predicted
-    assert next(iter(example_result["saltiness"]["predicted_vs_actual"])).actual == expected.actual
+    assert example_cv_result["saltiness"]["predicted_vs_actual"][0].predicted == expected.predicted
+    assert next(iter(example_cv_result["saltiness"]["predicted_vs_actual"])).actual == expected.actual


-def test_categorical_pva(example_result, example_categorical_pva_metrics):
+def test_categorical_pva(example_cv_result, example_categorical_pva_metrics):
     args = example_categorical_pva_metrics["value"][0]
     expected = PredictedVsActualCategoricalPoint.build(args)
-    assert example_result["salt?"]["predicted_vs_actual"][0].predicted == expected.predicted
-    assert next(iter(example_result["salt?"]["predicted_vs_actual"])).actual == expected.actual
+    assert example_cv_result["salt?"]["predicted_vs_actual"][0].predicted == expected.predicted
+    assert next(iter(example_cv_result["salt?"]["predicted_vs_actual"])).actual == expected.actual

tests/informatics/workflows/test_predictor_evaluation_workflow.py

Lines changed: 8 additions & 3 deletions
@@ -1,16 +1,20 @@
 import pytest
+import uuid

-from citrine.informatics.predictor_evaluator import CrossValidationEvaluator, PredictorEvaluator
+from citrine.informatics.data_sources import GemTableDataSource
+from citrine.informatics.predictor_evaluator import HoldoutSetEvaluator, CrossValidationEvaluator, PredictorEvaluator
 from citrine.informatics.workflows import PredictorEvaluationWorkflow


 @pytest.fixture()
 def pew():
-    evaluator = CrossValidationEvaluator(name="test", responses={"foo"})
+    data_source = GemTableDataSource(table_id=uuid.uuid4(), table_version=3)
+    evaluator1 = CrossValidationEvaluator(name="test CV", responses={"foo"})
+    evaluator2 = HoldoutSetEvaluator(name="test holdout", responses={"foo"}, data_source=data_source)
     pew = PredictorEvaluationWorkflow(
         name="Test",
         description="TestWorkflow",
-        evaluators=[evaluator]
+        evaluators=[evaluator1, evaluator2]
     )
     return pew

@@ -20,6 +24,7 @@ def test_round_robin(pew):
     assert dumped["name"] == "Test"
     assert dumped["description"] == "TestWorkflow"
     assert PredictorEvaluator.build(dumped["evaluators"][0]).name == pew.evaluators[0].name
+    assert PredictorEvaluator.build(dumped["evaluators"][1]).name == pew.evaluators[1].name


 def test_print(pew):

tests/resources/test_predictor_evaluation_executions.py

Lines changed: 4 additions & 3 deletions
@@ -61,15 +61,16 @@ def test_build_new_execution(collection, predictor_evaluation_execution_dict):
     assert execution.in_progress() and not execution.succeeded() and not execution.failed()


-def test_workflow_execution_results(workflow_execution: PredictorEvaluationExecution, session, example_result_dict):
+def test_workflow_execution_results(workflow_execution: PredictorEvaluationExecution, session,
+                                    example_cv_result_dict):
     # Given
-    session.set_response(example_result_dict)
+    session.set_response(example_cv_result_dict)

     # When
     results = workflow_execution["Example Evaluator"]

     # Then
-    assert results.evaluator == PredictorEvaluationResult.build(example_result_dict).evaluator
+    assert results.evaluator == PredictorEvaluationResult.build(example_cv_result_dict).evaluator
     expected_path = '/projects/{}/predictor-evaluation-executions/{}/results'.format(
         workflow_execution.project_id,
         workflow_execution.uid,
