Skip to content

Commit b1a31ab

Browse files
committed
[PNE-7018] Deprecate PEWs.
PEWs no longer need to be a customer facing asset type. Their only function is to allow reusing the same set of evaluators with different predictors, which A) doesn't often work due to mismatched responses, and B) isn't particularly helpful. As such, we can move users on to a cleaner, easier to use workflow, where predictor evaluation is kicked off directly from a predictor, and you can start one using the default evaluators with a single call. In addition to exposing additional endpoints and adding evaluators to the PEW executions, this also deprecates the predictor_evaluation_workflow property, and other Predictor Evaluation Workflow related assets. It also deprecates properties and classes named something akin to "predictor evaluation execution", in favor of the simpler and clearer "predictor evaluation".
1 parent 019b15d commit b1a31ab

File tree

9 files changed

+367
-21
lines changed

9 files changed

+367
-21
lines changed

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.22.1"
1+
__version__ = "3.23.0"
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
# flake8: noqa
2-
from .predictor_evaluation_execution import *
32
from .design_execution import *
43
from .generative_design_execution import *
4+
from .predictor_evaluation import *
5+
from .predictor_evaluation_execution import *
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from functools import lru_cache
2+
3+
from citrine.informatics.executions.execution import Execution
4+
from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult
5+
from citrine.informatics.predictor_evaluator import PredictorEvaluator
6+
from citrine._rest.resource import Resource
7+
from citrine._serialization import properties
8+
from citrine._utils.functions import format_escaped_url
9+
10+
11+
class PredictorEvaluation(Resource['PredictorEvaluation'], Execution):
12+
"""The execution of a PredictorEvaluationWorkflow.
13+
14+
Possible statuses are INPROGRESS, SUCCEEDED, and FAILED.
15+
Predictor evaluation executions also have a ``status_description`` field with more information.
16+
17+
"""
18+
19+
evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators",
20+
serializable=False)
21+
""":List{PredictorEvaluator]:the predictor evaluators that were executed. These are used
22+
when calling the ``results()`` method."""
23+
workflow_id = properties.UUID('workflow_id', serializable=False)
24+
""":UUID: Unique identifier of the workflow that was executed"""
25+
predictor_id = properties.UUID('predictor_id', serializable=False)
26+
predictor_version = properties.Integer('predictor_version', serializable=False)
27+
28+
def _path(self):
29+
return format_escaped_url(
30+
'/projects/{project_id}/predictor-evaluation-executions/{execution_id}',
31+
project_id=str(self.project_id),
32+
execution_id=str(self.uid)
33+
)
34+
35+
@lru_cache()
36+
def results(self, evaluator_name: str) -> PredictorEvaluationResult:
37+
"""
38+
Get a specific evaluation result by the name of the evaluator that produced it.
39+
40+
Parameters
41+
----------
42+
evaluator_name: str
43+
Name of the evaluator for which to get the results
44+
45+
Returns
46+
-------
47+
PredictorEvaluationResult
48+
The evaluation result from the evaluator with the given name
49+
50+
"""
51+
params = {"evaluator_name": evaluator_name}
52+
resource = self._session.get_resource(self._path() + "/results", params=params)
53+
return PredictorEvaluationResult.build(resource)
54+
55+
@property
56+
def evaluator_names(self):
57+
"""Names of the predictor evaluators. Used when calling the ``results()`` method."""
58+
return list(iter(self))
59+
60+
def __getitem__(self, item):
61+
if isinstance(item, str):
62+
return self.results(item)
63+
else:
64+
raise TypeError("Results are accessed by string names")
65+
66+
def __iter__(self):
67+
return iter(e.name for e in self.evaluators)

src/citrine/informatics/predictor_evaluator.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
from citrine.informatics.predictor_evaluation_metrics import PredictorEvaluationMetric
77
from citrine.informatics.data_sources import DataSource
88

9-
__all__ = ['PredictorEvaluator',
10-
'CrossValidationEvaluator',
11-
'HoldoutSetEvaluator'
12-
]
9+
__all__ = [
10+
'CrossValidationEvaluator',
11+
'HoldoutSetEvaluator',
12+
'PredictorEvaluator'
13+
]
1314

1415

1516
class PredictorEvaluator(PolymorphicSerializable["PredictorEvaluator"]):

src/citrine/jobs/waiting.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from citrine.informatics.executions.design_execution import DesignExecution
88
from citrine.informatics.executions.generative_design_execution import GenerativeDesignExecution
99
from citrine.informatics.executions.sample_design_space_execution import SampleDesignSpaceExecution
10-
from citrine.informatics.executions import PredictorEvaluationExecution
10+
from citrine.informatics.executions import PredictorEvaluation, PredictorEvaluationExecution
1111

1212

1313
class ConditionTimeoutError(RuntimeError):
@@ -130,12 +130,14 @@ def wait_while_validating(
130130
def wait_while_executing(
131131
*,
132132
collection: Union[
133+
Collection[PredictorEvaluation],
133134
Collection[PredictorEvaluationExecution],
134135
Collection[DesignExecution],
135136
Collection[GenerativeDesignExecution],
136137
Collection[SampleDesignSpaceExecution]
137138
],
138139
execution: Union[
140+
PredictorEvaluation,
139141
PredictorEvaluationExecution,
140142
DesignExecution,
141143
GenerativeDesignExecution,
@@ -145,6 +147,7 @@ def wait_while_executing(
145147
timeout: float = 1800.0,
146148
interval: float = 3.0,
147149
) -> Union[
150+
PredictorEvaluation,
148151
PredictorEvaluationExecution,
149152
DesignExecution,
150153
GenerativeDesignExecution,
@@ -155,20 +158,24 @@ def wait_while_executing(
155158
156159
Parameters
157160
----------
158-
execution : Union[PredictorEvaluationExecution, DesignExecution]
161+
execution : Union[PredictorEvaluation, PredictorEvaluationExecution, DesignExecution,
162+
GenerativeDesignExecution, SampleDesignSpaceExecution]
159163
an execution to monitor
160164
print_status_info : bool, optional
161165
Whether to print status info, by default False
162166
timeout : float, optional
163167
Maximum time spent inquiring in seconds, by default 1800.0
164168
interval : float, optional
165169
Inquiry interval in seconds, by default 3.0
166-
collection : Union[Collection[PredictorEvaluationExecution], Collection[DesignExecution]]
170+
collection : Union[Collection[PredictorEvaluation], Collection[PredictorEvaluationExecution],
171+
Collection[DesignExecution], Collection[GenerativeDesignExecution],
172+
Collection[SampleDesignSpaceExecution]]
167173
Collection containing executions
168174
169175
Returns
170176
-------
171-
Union[PredictorEvaluationExecution, DesignExecution]
177+
execution : Union[PredictorEvaluation, PredictorEvaluationExecution, DesignExecution,
178+
GenerativeDesignExecution, SampleDesignSpaceExecution]
172179
the updated execution after it has finished executing
173180
174181

src/citrine/resources/predictor.py

Lines changed: 136 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,16 +9,29 @@
99
from citrine._rest.resource import Resource
1010
from citrine._rest.paginator import Paginator
1111
from citrine._serialization import properties
12+
from citrine._serialization.serializable import Serializable
1213
from citrine._session import Session
1314
from citrine.informatics.data_sources import DataSource
1415
from citrine.informatics.design_candidate import HierarchicalDesignMaterial
16+
from citrine.informatics.executions.predictor_evaluation import PredictorEvaluation
17+
from citrine.informatics.predictor_evaluator import PredictorEvaluator
1518
from citrine.informatics.predictors import GraphPredictor
1619
from citrine.resources.status_detail import StatusDetail
1720

1821

19-
# Refers to the most recently edited prediction version. Could be a draft.
22+
# The most recently edited prediction version. Could be a draft.
2023
MOST_RECENT_VER = "most_recent"
21-
LATEST_VER = "latest" # Refers to the highest saved predictor version.
24+
# The highest trained predictor version.
25+
LATEST_VER = "latest"
26+
27+
28+
class EvaluatorsPayload(Serializable['EvaluatorsPayload']):
29+
"""Container object for predictor evaluators."""
30+
31+
evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators")
32+
33+
def __init__(self, evaluators):
34+
self.evaluators = evaluators
2235

2336

2437
class AsyncDefaultPredictor(Resource["AsyncDefaultPredictor"]):
@@ -196,6 +209,38 @@ def rename(self,
196209
entity = self.session.put_resource(path, json, version=self._api_version)
197210
return self.build(entity)
198211

212+
def default_evaluators(self,
213+
uid: Union[UUID, str],
214+
*,
215+
version: Union[int, str]) -> List[PredictorEvaluator]:
216+
path = self._construct_path(uid, version, action="default-evaluators")
217+
evaluators = self.session.get_resource(path)
218+
return EvaluatorsPayload.build(evaluators).evaluators
219+
220+
def _build_predictor_evaluation(self, data):
221+
evaluation = PredictorEvaluation.build(data)
222+
evaluation.project_id = self.project_id
223+
evaluation._session = self.session
224+
return evaluation
225+
226+
def evaluate(self,
227+
uid: Union[UUID, str],
228+
*,
229+
version: Union[int, str],
230+
evaluators: List[PredictorEvaluator]) -> PredictorEvaluation:
231+
path = self._construct_path(uid, version, "evaluate")
232+
payload = EvaluatorsPayload(evaluators=evaluators).dump()
233+
response = self.session.post_resource(path, payload, version=self._api_version)
234+
return self._build_predictor_evaluation(response)
235+
236+
def evaluate_default(self,
237+
uid: Union[UUID, str],
238+
*,
239+
version: Union[int, str]) -> PredictorEvaluation:
240+
path = self._construct_path(uid, version, "evaluate-default")
241+
response = self.session.post_resource(path, {}, version=self._api_version)
242+
return self._build_predictor_evaluation(response)
243+
199244
def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER):
200245
"""Predictor versions cannot be deleted at this time."""
201246
msg = "Predictor versions cannot be deleted. Use 'archive_version' instead."
@@ -578,6 +623,95 @@ def rename(self,
578623
uid, version=version, name=name, description=description
579624
)
580625

626+
def default_evaluators_from_config(self,
627+
predictor: GraphPredictor) -> List[PredictorEvaluator]:
628+
"""Retrieve the default evaluators for an arbitrary (but valid) predictor config.
629+
630+
See :func:`~citrine.resources.PredictorCollection.default_evaluators` for details on the
631+
evaluators.
632+
"""
633+
path = self._get_path(action="default-evaluators-from-config")
634+
payload = predictor.dump()
635+
evaluators = self.session.post_resource(path, json=payload, version=self._api_version)
636+
return EvaluatorsPayload.build(evaluators).evaluators
637+
638+
def default_evaluators(self,
639+
uid: Union[UUID, str],
640+
*,
641+
version: Union[int, str] = MOST_RECENT_VER) -> List[PredictorEvaluator]:
642+
"""Retrieve the default evaluators for a stored predictor.
643+
644+
The current default evaluators perform 5-fold, 3-trial cross-validation on all valid
645+
predictor responses. Valid responses are those that are **not** produced by the
646+
following predictors:
647+
648+
* :class:`~citrine.informatics.predictors.generalized_mean_property_predictor.GeneralizedMeanPropertyPredictor`
649+
* :class:`~citrine.informatics.predictors.mean_property_predictor.MeanPropertyPredictor`
650+
* :class:`~citrine.informatics.predictors.ingredient_fractions_predictor.IngredientFractionsPredictor`
651+
* :class:`~citrine.informatics.predictors.ingredients_to_simple_mixture_predictor.IngredientsToSimpleMixturePredictor`
652+
* :class:`~citrine.informatics.predictors.ingredients_to_formulation_predictor.IngredientsToFormulationPredictor`
653+
* :class:`~citrine.informatics.predictors.label_fractions_predictor.LabelFractionsPredictor`
654+
* :class:`~citrine.informatics.predictors.molecular_structure_featurizer.MolecularStructureFeaturizer`
655+
* :class:`~citrine.informatics.predictors.simple_mixture_predictor.SimpleMixturePredictor`
656+
657+
Parameters
658+
----------
659+
predictor_id: UUID
660+
Unique identifier of the predictor to evaluate
661+
predictor_version: Option[Union[int, str]]
662+
The version of the predictor to evaluate
663+
664+
Returns
665+
-------
666+
PredictorEvaluation
667+
668+
""" # noqa: E501,W505
669+
return self._versions_collection.default_evaluators(uid, version=version)
670+
671+
def evaluate(self,
672+
uid: Union[UUID, str],
673+
*,
674+
version: Union[int, str] = LATEST_VER,
675+
evaluators: List[PredictorEvaluator]) -> PredictorEvaluation:
676+
"""Evaluate a predictor using the provided evaluators.
677+
678+
Parameters
679+
----------
680+
predictor_id: UUID
681+
Unique identifier of the predictor to evaluate
682+
predictor_version: Option[Union[int, str]]
683+
The version of the predictor to evaluate. Defaults to the latest trained version.
684+
evaluators: list
685+
686+
Returns
687+
-------
688+
PredictorEvaluation
689+
690+
"""
691+
return self._versions_collection.evaluate(uid, version=version, evaluators=evaluators)
692+
693+
def evaluate_default(self,
694+
uid: Union[UUID, str],
695+
*,
696+
version: Union[int, str] = MOST_RECENT_VER) -> PredictorEvaluation:
697+
"""Evaluate a predictor using the default evaluators.
698+
699+
See :func:`~citrine.resources.PredictorCollection.default_evaluators` for details on the evaluators.
700+
701+
Parameters
702+
----------
703+
predictor_id: UUID
704+
Unique identifier of the predictor to evaluate
705+
predictor_version: Option[Union[int, str]]
706+
The version of the predictor to evaluate
707+
708+
Returns
709+
-------
710+
PredictorEvaluation
711+
712+
""" # noqa: E501,W505
713+
return self._versions_collection.evaluate_default(uid, version=version)
714+
581715
def delete(self, uid: Union[UUID, str]):
582716
"""Predictors cannot be deleted at this time."""
583717
msg = "Predictors cannot be deleted. Use 'archive_version' or 'archive_root' instead."

0 commit comments

Comments
 (0)