Skip to content

Commit de7d659

Browse files
committed
[PNE-7018] Add new evaluations API.
Working on adding tests now. Need to add evaluators to the factory file, then use it for better testing. Also need to deprecate the PEW stuff.
1 parent 019b15d commit de7d659

File tree

8 files changed

+361
-11
lines changed

8 files changed

+361
-11
lines changed

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.22.1"
1+
__version__ = "3.23.0"

src/citrine/_rest/engine_resource.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,14 @@ def is_archived(self):
4242
return self.archived_by is not None
4343

4444
def _post_dump(self, data: dict) -> dict:
45-
# Only the data portion of an entity is sent to the server.
46-
data = data["data"]
47-
48-
if "instance" in data:
49-
# Currently, name and description exists on both the data envelope and the config.
50-
data["instance"]["name"] = data["name"]
51-
data["instance"]["description"] = data["description"]
45+
if data:
46+
# Only the data portion of an entity is sent to the server.
47+
data = data["data"]
48+
49+
if "instance" in data:
50+
# Currently, name and description exists on both the data envelope and the config.
51+
data["instance"]["name"] = data["name"]
52+
data["instance"]["description"] = data["description"]
5253

5354
return super()._post_dump(data)
5455

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
from functools import lru_cache
2+
from typing import List, Optional, Union
3+
from uuid import UUID
4+
5+
from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult
6+
from citrine.informatics.predictor_evaluator import PredictorEvaluator
7+
from citrine.resources.status_detail import StatusDetail
8+
from citrine._rest.engine_resource import EngineResourceWithoutStatus
9+
from citrine._rest.resource import PredictorRef
10+
from citrine._serialization import properties
11+
from citrine._serialization.serializable import Serializable
12+
from citrine._utils.functions import format_escaped_url
13+
14+
15+
class PredictorEvaluatorsResponse(Serializable['EvaluatorsPayload']):
16+
"""Container object for a default predictor evaluator response."""
17+
18+
evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators")
19+
20+
def __init__(self, evaluators: List[PredictorEvaluator]):
21+
self.evaluators = evaluators
22+
23+
24+
class PredictorEvaluationRequest(Serializable['EvaluatorsPayload']):
25+
"""Container object for a predictor evaluation request."""
26+
27+
predictor = properties.Object(PredictorRef, "predictor")
28+
evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators")
29+
30+
def __init__(self,
31+
*,
32+
evaluators: List[PredictorEvaluator],
33+
predictor_id: Union[UUID, str],
34+
predictor_version: Optional[Union[int, str]] = None):
35+
self.evaluators = evaluators
36+
self.predictor = PredictorRef(predictor_id, predictor_version)
37+
38+
39+
class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation']):
40+
"""The evaluation of a predictor's performance."""
41+
42+
uid: UUID = properties.UUID('id', serializable=False)
43+
""":UUID: Unique identifier of the evaluation"""
44+
evaluators = properties.List(properties.Object(PredictorEvaluator), "data.evaluators",
45+
serializable=False)
46+
""":List{PredictorEvaluator]:the predictor evaluators that were executed. These are used
47+
when calling the ``results()`` method."""
48+
predictor_id = properties.UUID('metadata.predictor_id', serializable=False)
49+
""":UUID:"""
50+
predictor_version = properties.Integer('metadata.predictor_version', serializable=False)
51+
status = properties.String('metadata.status.major', serializable=False)
52+
""":str: short description of the evaluation's status"""
53+
status_description = properties.String('metadata.status.minor', serializable=False)
54+
""":str: more detailed description of the evaluation's status"""
55+
status_detail = properties.List(properties.Object(StatusDetail), 'metadata.status.detail',
56+
default=[], serializable=False)
57+
""":List[StatusDetail]: a list of structured status info, containing the message and level"""
58+
59+
def _path(self):
60+
return format_escaped_url(
61+
'/projects/{project_id}/predictor-evaluations/{evaluation_id}',
62+
project_id=str(self.project_id),
63+
evaluation_id=str(self.uid)
64+
)
65+
66+
@lru_cache()
67+
def results(self, evaluator_name: str) -> PredictorEvaluationResult:
68+
"""
69+
Get a specific evaluation result by the name of the evaluator that produced it.
70+
71+
Parameters
72+
----------
73+
evaluator_name: str
74+
Name of the evaluator for which to get the results
75+
76+
Returns
77+
-------
78+
PredictorEvaluationResult
79+
The evaluation result from the evaluator with the given name
80+
81+
"""
82+
params = {"evaluator_name": evaluator_name}
83+
resource = self._session.get_resource(self._path() + "/results", params=params)
84+
return PredictorEvaluationResult.build(resource)
85+
86+
@property
87+
def evaluator_names(self):
88+
"""Names of the predictor evaluators. Used when calling the ``results()`` method."""
89+
return list(iter(self))
90+
91+
def __getitem__(self, item):
92+
if isinstance(item, str):
93+
return self.results(item)
94+
else:
95+
raise TypeError("Results are accessed by string names")
96+
97+
def __iter__(self):
98+
return iter(e.name for e in self.evaluators)

src/citrine/resources/design_space.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,15 +129,15 @@ def _list_base(self, *, per_page: int = 100, archived: Optional[bool] = None):
129129
per_page=per_page)
130130

131131
def list_all(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
132-
"""List the most recent version of all design spaces."""
132+
"""List all design spaces."""
133133
return self._list_base(per_page=per_page)
134134

135135
def list(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
136-
"""List the most recent version of all non-archived design spaces."""
136+
"""List non-archived design spaces."""
137137
return self._list_base(per_page=per_page, archived=False)
138138

139139
def list_archived(self, *, per_page: int = 20) -> Iterable[DesignSpace]:
140-
"""List the most recent version of all archived predictors."""
140+
"""List archived design spaces."""
141141
return self._list_base(per_page=per_page, archived=True)
142142

143143
def create_default(self,
Lines changed: 219 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,219 @@
1+
from functools import partial
2+
from typing import Iterable, Iterator, List, Optional, Union
3+
from uuid import UUID
4+
5+
from citrine.informatics.executions.predictor_evaluation import PredictorEvaluation, \
6+
PredictorEvaluationRequest, PredictorEvaluatorsResponse
7+
from citrine.informatics.predictor_evaluator import PredictorEvaluator
8+
from citrine.informatics.predictors import GraphPredictor
9+
from citrine.resources.predictor import LATEST_VER as LATEST_PRED_VER
10+
from citrine._rest.collection import Collection
11+
from citrine._rest.resource import PredictorRef
12+
from citrine._session import Session
13+
14+
15+
class PredictorEvaluationCollection(Collection[PredictorEvaluation]):
16+
"""Represents the collection of predictor evaluations.
17+
18+
Parameters
19+
----------
20+
project_id: UUID
21+
the UUID of the project
22+
23+
"""
24+
25+
_api_version = 'v1'
26+
_path_template = '/projects/{project_id}/predictor-evaluations'
27+
_individual_key = None
28+
_resource = PredictorEvaluation
29+
_collection_key = 'response'
30+
31+
def __init__(self, project_id: UUID, session: Session):
32+
self.project_id = project_id
33+
self.session: Session = session
34+
35+
def build(self, data: dict) -> PredictorEvaluation:
36+
"""Build an individual predictor evaluation."""
37+
evaluation = PredictorEvaluation.build(data)
38+
evaluation._session = self.session
39+
evaluation._project_id = self.project_id
40+
return evaluation
41+
42+
def _list_base(self,
43+
*,
44+
per_page: int = 100,
45+
predictor_id: Optional[UUID] = None,
46+
predictor_version: Optional[Union[int, str]] = None,
47+
archived: Optional[bool] = None
48+
) -> Iterator[PredictorEvaluation]:
49+
params = {"archived": archived}
50+
if predictor_id is not None:
51+
params["predictor_id"] = str(predictor_id)
52+
if predictor_version is not None:
53+
params["predictor_version"] = predictor_version
54+
55+
fetcher = partial(self._fetch_page, additional_params=params)
56+
return self._paginator.paginate(page_fetcher=fetcher,
57+
collection_builder=self._build_collection_elements,
58+
per_page=per_page)
59+
60+
def list_all(self,
61+
*,
62+
per_page: int = 100,
63+
predictor_id: Optional[UUID] = None,
64+
predictor_version: Optional[Union[int, str]] = None
65+
) -> Iterable[PredictorEvaluation]:
66+
"""List all predictor evaluations."""
67+
return self._list_base(per_page=per_page,
68+
predictor_id=predictor_id,
69+
predictor_version=predictor_version)
70+
71+
def list(self,
72+
*,
73+
per_page: int = 100,
74+
predictor_id: Optional[UUID] = None,
75+
predictor_version: Optional[Union[int, str]] = None
76+
) -> Iterable[PredictorEvaluation]:
77+
"""List non-archived predictor evaluations."""
78+
return self._list_base(per_page=per_page,
79+
predictor_id=predictor_id,
80+
predictor_version=predictor_version,
81+
archived=False)
82+
83+
def list_archived(self,
84+
*,
85+
per_page: int = 100,
86+
predictor_id: Optional[UUID] = None,
87+
predictor_version: Optional[Union[int, str]] = None
88+
) -> Iterable[PredictorEvaluation]:
89+
"""List archived predictor evaluations."""
90+
return self._list_base(per_page=per_page,
91+
predictor_id=predictor_id,
92+
predictor_version=predictor_version,
93+
archived=True)
94+
95+
def archive(self, uid: Union[UUID, str]):
96+
"""Archive an evaluation."""
97+
url = self._get_path(uid, action="archive")
98+
result = self.session.put_resource(url, {}, version=self._api_version)
99+
return self.build(result)
100+
101+
def restore(self, uid: Union[UUID, str]):
102+
"""Restore an archived evaluation."""
103+
url = self._get_path(uid, action="restore")
104+
result = self.session.put_resource(url, {}, version=self._api_version)
105+
return self.build(result)
106+
107+
def default_from_config(self, config: GraphPredictor) -> List[PredictorEvaluator]:
108+
"""Retrieve the default evaluators for an arbitrary (but valid) predictor config.
109+
110+
See :func:`~citrine.resources.PredictorEvaluationCollection.default_evaluators` for details
111+
on the resulting evaluators.
112+
"""
113+
path = self._get_path(action="default-from-config")
114+
payload = config.dump()["instance"]
115+
result = self.session.post_resource(path, json=payload, version=self._api_version)
116+
return PredictorEvaluatorsResponse.build(result).evaluators
117+
118+
def default(self,
119+
*,
120+
predictor_id: Union[UUID, str],
121+
predictor_version: Union[int, str] = LATEST_PRED_VER
122+
) -> List[PredictorEvaluator]:
123+
"""Retrieve the default evaluators for a stored predictor.
124+
125+
The current default evaluators perform 5-fold, 3-trial cross-validation on all valid
126+
predictor responses. Valid responses are those that are **not** produced by the
127+
following predictors:
128+
129+
* :class:`~citrine.informatics.predictors.generalized_mean_property_predictor.GeneralizedMeanPropertyPredictor`
130+
* :class:`~citrine.informatics.predictors.mean_property_predictor.MeanPropertyPredictor`
131+
* :class:`~citrine.informatics.predictors.ingredient_fractions_predictor.IngredientFractionsPredictor`
132+
* :class:`~citrine.informatics.predictors.ingredients_to_simple_mixture_predictor.IngredientsToSimpleMixturePredictor`
133+
* :class:`~citrine.informatics.predictors.ingredients_to_formulation_predictor.IngredientsToFormulationPredictor`
134+
* :class:`~citrine.informatics.predictors.label_fractions_predictor.LabelFractionsPredictor`
135+
* :class:`~citrine.informatics.predictors.molecular_structure_featurizer.MolecularStructureFeaturizer`
136+
* :class:`~citrine.informatics.predictors.simple_mixture_predictor.SimpleMixturePredictor`
137+
138+
Parameters
139+
----------
140+
predictor_id: UUID
141+
Unique identifier of the predictor to evaluate
142+
predictor_version: Option[Union[int, str]]
143+
The version of the predictor to evaluate
144+
145+
Returns
146+
-------
147+
PredictorEvaluation
148+
149+
""" # noqa: E501,W505
150+
path = self._get_path(action="default")
151+
payload = PredictorRef(uid=predictor_id, version=predictor_version).dump()
152+
result = self.session.post_resource(path, json=payload, version=self._api_version)
153+
return PredictorEvaluatorsResponse.build(result).evaluators
154+
155+
def trigger(self,
156+
*,
157+
predictor_id: Union[UUID, str],
158+
predictor_version: Union[int, str] = LATEST_PRED_VER,
159+
evaluators: List[PredictorEvaluator]) -> PredictorEvaluation:
160+
"""Evaluate a predictor using the provided evaluators.
161+
162+
Parameters
163+
----------
164+
predictor_id: UUID
165+
Unique identifier of the predictor to evaluate
166+
predictor_version: Option[Union[int, str]]
167+
The version of the predictor to evaluate. Defaults to the latest trained version.
168+
evaluators: List[PredictorEvaluator]
169+
The evaluators to use to measure predictor performance.
170+
171+
Returns
172+
-------
173+
PredictorEvaluation
174+
175+
"""
176+
path = self._get_path("trigger")
177+
payload = PredictorEvaluationRequest(evaluators=evaluators,
178+
predictor_id=predictor_id,
179+
predictor_version=predictor_version).dump()
180+
result = self.session.post_resource(path, payload, version=self._api_version)
181+
return self.build(result)
182+
183+
def trigger_default(self,
184+
*,
185+
predictor_id: Union[UUID, str],
186+
predictor_version: Union[int, str] = LATEST_PRED_VER
187+
) -> PredictorEvaluation:
188+
"""Evaluate a predictor using the default evaluators.
189+
190+
See :func:`~citrine.resources.PredictorCollection.default_evaluators` for details on the evaluators.
191+
192+
Parameters
193+
----------
194+
predictor_id: UUID
195+
Unique identifier of the predictor to evaluate
196+
predictor_version: Option[Union[int, str]]
197+
The version of the predictor to evaluate
198+
199+
Returns
200+
-------
201+
PredictorEvaluation
202+
203+
""" # noqa: E501,W505
204+
path = self._get_path("trigger-default")
205+
payload = PredictorRef(uid=predictor_id, version=predictor_version).dump()
206+
result = self.session.post_resource(path, json=payload, version=self._api_version)
207+
return self.build(result)
208+
209+
def register(self, model: PredictorEvaluation) -> PredictorEvaluation:
210+
"""Cannot register an evaluation."""
211+
raise NotImplementedError("Cannot register a PredictorEvaluation.")
212+
213+
def update(self, model: PredictorEvaluation) -> PredictorEvaluation:
214+
"""Cannot update an evaluation."""
215+
raise NotImplementedError("Cannot update a PredictorEvaluation.")
216+
217+
def delete(self, uid: Union[UUID, str]):
218+
"""Cannot delete an evaluation."""
219+
raise NotImplementedError("Cannot delete a PredictorEvaluation.")

src/citrine/resources/project.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
PredictorEvaluationExecutionCollection
4141
from citrine.resources.predictor_evaluation_workflow import \
4242
PredictorEvaluationWorkflowCollection
43+
from citrine.resources.predictor_evaluation import PredictorEvaluationCollection
4344
from citrine.resources.generative_design_execution import \
4445
GenerativeDesignExecutionCollection
4546
from citrine.resources.project_member import ProjectMember
@@ -148,6 +149,11 @@ def predictor_evaluation_executions(self) -> PredictorEvaluationExecutionCollect
148149
"""Return a collection representing all visible predictor evaluation executions."""
149150
return PredictorEvaluationExecutionCollection(project_id=self.uid, session=self.session)
150151

152+
@property
153+
def predictor_evaluations(self) -> PredictorEvaluationCollection:
154+
"""Return a collection representing all visible predictor evaluations."""
155+
return PredictorEvaluationCollection(project_id=self.uid, session=self.session)
156+
151157
@property
152158
def design_workflows(self) -> DesignWorkflowCollection:
153159
"""Return a collection representing all visible design workflows."""

tests/resources/test_project.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,10 @@ def test_pe_executions_get_project_id(project):
314314
project.predictor_evaluation_executions.trigger(uuid.uuid4())
315315

316316

317+
def test_predictor_evaluations_get_project_id(project):
318+
assert project.uid == project.predictor_evaluations.project_id
319+
320+
317321
def test_design_workflows_get_project_id(project):
318322
assert project.uid == project.design_workflows.project_id
319323

0 commit comments

Comments
 (0)