|
9 | 9 | from citrine._rest.resource import Resource |
10 | 10 | from citrine._rest.paginator import Paginator |
11 | 11 | from citrine._serialization import properties |
| 12 | +from citrine._serialization.serializable import Serializable |
12 | 13 | from citrine._session import Session |
13 | 14 | from citrine.informatics.data_sources import DataSource |
14 | 15 | from citrine.informatics.design_candidate import HierarchicalDesignMaterial |
| 16 | +from citrine.informatics.executions.predictor_evaluation import PredictorEvaluation |
| 17 | +from citrine.informatics.predictor_evaluator import PredictorEvaluator |
15 | 18 | from citrine.informatics.predictors import GraphPredictor |
16 | 19 | from citrine.resources.status_detail import StatusDetail |
17 | 20 |
|
18 | 21 |
|
19 | | -# Refers to the most recently edited prediction version. Could be a draft. |
| 22 | +# The most recently edited prediction version. Could be a draft. |
20 | 23 | MOST_RECENT_VER = "most_recent" |
21 | | -LATEST_VER = "latest" # Refers to the highest saved predictor version. |
| 24 | +# The highest trained predictor version. |
| 25 | +LATEST_VER = "latest" |
| 26 | + |
| 27 | + |
| 28 | +class EvaluatorsPayload(Serializable['EvaluatorsPayload']): |
| 29 | + """Container object for predictor evaluators.""" |
| 30 | + |
| 31 | + evaluators = properties.List(properties.Object(PredictorEvaluator), "evaluators") |
| 32 | + |
| 33 | + def __init__(self, evaluators): |
| 34 | + self.evaluators = evaluators |
22 | 35 |
|
23 | 36 |
|
24 | 37 | class AsyncDefaultPredictor(Resource["AsyncDefaultPredictor"]): |
@@ -196,6 +209,38 @@ def rename(self, |
196 | 209 | entity = self.session.put_resource(path, json, version=self._api_version) |
197 | 210 | return self.build(entity) |
198 | 211 |
|
| 212 | + def default_evaluators(self, |
| 213 | + uid: Union[UUID, str], |
| 214 | + *, |
| 215 | + version: Union[int, str]) -> List[PredictorEvaluator]: |
| 216 | + path = self._construct_path(uid, version, action="default-evaluators") |
| 217 | + evaluators = self.session.get_resource(path) |
| 218 | + return EvaluatorsPayload.build(evaluators).evaluators |
| 219 | + |
| 220 | + def _build_predictor_evaluation(self, data): |
| 221 | + evaluation = PredictorEvaluation.build(data) |
| 222 | + evaluation.project_id = self.project_id |
| 223 | + evaluation._session = self.session |
| 224 | + return evaluation |
| 225 | + |
| 226 | + def evaluate(self, |
| 227 | + uid: Union[UUID, str], |
| 228 | + *, |
| 229 | + version: Union[int, str], |
| 230 | + evaluators: List[PredictorEvaluator]) -> PredictorEvaluation: |
| 231 | + path = self._construct_path(uid, version, "evaluate") |
| 232 | + payload = EvaluatorsPayload(evaluators=evaluators).dump() |
| 233 | + response = self.session.post_resource(path, payload, version=self._api_version) |
| 234 | + return self._build_predictor_evaluation(response) |
| 235 | + |
| 236 | + def evaluate_default(self, |
| 237 | + uid: Union[UUID, str], |
| 238 | + *, |
| 239 | + version: Union[int, str]) -> PredictorEvaluation: |
| 240 | + path = self._construct_path(uid, version, "evaluate-default") |
| 241 | + response = self.session.post_resource(path, {}, version=self._api_version) |
| 242 | + return self._build_predictor_evaluation(response) |
| 243 | + |
199 | 244 | def delete(self, uid: Union[UUID, str], *, version: Union[int, str] = MOST_RECENT_VER): |
200 | 245 | """Predictor versions cannot be deleted at this time.""" |
201 | 246 | msg = "Predictor versions cannot be deleted. Use 'archive_version' instead." |
@@ -578,6 +623,95 @@ def rename(self, |
578 | 623 | uid, version=version, name=name, description=description |
579 | 624 | ) |
580 | 625 |
|
| 626 | + def default_evaluators_from_config(self, |
| 627 | + predictor: GraphPredictor) -> List[PredictorEvaluator]: |
| 628 | + """Retrieve the default evaluators for an arbitrary (but valid) predictor config. |
| 629 | +
|
| 630 | + See :func:`~citrine.resources.PredictorCollection.default_evaluators` for details on the |
| 631 | + evaluators. |
| 632 | + """ |
| 633 | + path = self._get_path(action="default-evaluators-from-config") |
| 634 | + payload = predictor.dump() |
| 635 | + evaluators = self.session.post_resource(path, json=payload, version=self._api_version) |
| 636 | + return EvaluatorsPayload.build(evaluators).evaluators |
| 637 | + |
| 638 | + def default_evaluators(self, |
| 639 | + uid: Union[UUID, str], |
| 640 | + *, |
| 641 | + version: Union[int, str] = MOST_RECENT_VER) -> List[PredictorEvaluator]: |
| 642 | + """Retrieve the default evaluators for a stored predictor. |
| 643 | +
|
| 644 | + The current default evaluators perform 5-fold, 3-trial cross-validation on all valid |
| 645 | + predictor responses. Valid responses are those that are **not** produced by the |
| 646 | + following predictors: |
| 647 | +
|
| 648 | + * :class:`~citrine.informatics.predictors.generalized_mean_property_predictor.GeneralizedMeanPropertyPredictor` |
| 649 | + * :class:`~citrine.informatics.predictors.mean_property_predictor.MeanPropertyPredictor` |
| 650 | + * :class:`~citrine.informatics.predictors.ingredient_fractions_predictor.IngredientFractionsPredictor` |
| 651 | + * :class:`~citrine.informatics.predictors.ingredients_to_simple_mixture_predictor.IngredientsToSimpleMixturePredictor` |
| 652 | + * :class:`~citrine.informatics.predictors.ingredients_to_formulation_predictor.IngredientsToFormulationPredictor` |
| 653 | + * :class:`~citrine.informatics.predictors.label_fractions_predictor.LabelFractionsPredictor` |
| 654 | + * :class:`~citrine.informatics.predictors.molecular_structure_featurizer.MolecularStructureFeaturizer` |
| 655 | + * :class:`~citrine.informatics.predictors.simple_mixture_predictor.SimpleMixturePredictor` |
| 656 | +
|
| 657 | + Parameters |
| 658 | + ---------- |
| 659 | + predictor_id: UUID |
| 660 | + Unique identifier of the predictor to evaluate |
| 661 | + predictor_version: Option[Union[int, str]] |
| 662 | + The version of the predictor to evaluate |
| 663 | +
|
| 664 | + Returns |
| 665 | + ------- |
| 666 | + PredictorEvaluation |
| 667 | +
|
| 668 | + """ # noqa: E501,W505 |
| 669 | + return self._versions_collection.default_evaluators(uid, version=version) |
| 670 | + |
| 671 | + def evaluate(self, |
| 672 | + uid: Union[UUID, str], |
| 673 | + *, |
| 674 | + version: Union[int, str] = LATEST_VER, |
| 675 | + evaluators: List[PredictorEvaluator]) -> PredictorEvaluation: |
| 676 | + """Evaluate a predictor using the provided evaluators. |
| 677 | +
|
| 678 | + Parameters |
| 679 | + ---------- |
| 680 | + predictor_id: UUID |
| 681 | + Unique identifier of the predictor to evaluate |
| 682 | + predictor_version: Option[Union[int, str]] |
| 683 | + The version of the predictor to evaluate. Defaults to the latest trained version. |
| 684 | + evaluators: list |
| 685 | +
|
| 686 | + Returns |
| 687 | + ------- |
| 688 | + PredictorEvaluation |
| 689 | +
|
| 690 | + """ |
| 691 | + return self._versions_collection.evaluate(uid, version=version, evaluators=evaluators) |
| 692 | + |
| 693 | + def evaluate_default(self, |
| 694 | + uid: Union[UUID, str], |
| 695 | + *, |
| 696 | + version: Union[int, str] = MOST_RECENT_VER) -> PredictorEvaluation: |
| 697 | + """Evaluate a predictor using the default evaluators. |
| 698 | +
|
| 699 | + See :func:`~citrine.resources.PredictorCollection.default_evaluators` for details on the evaluators. |
| 700 | +
|
| 701 | + Parameters |
| 702 | + ---------- |
| 703 | + predictor_id: UUID |
| 704 | + Unique identifier of the predictor to evaluate |
| 705 | + predictor_version: Option[Union[int, str]] |
| 706 | + The version of the predictor to evaluate |
| 707 | +
|
| 708 | + Returns |
| 709 | + ------- |
| 710 | + PredictorEvaluation |
| 711 | +
|
| 712 | + """ # noqa: E501,W505 |
| 713 | + return self._versions_collection.evaluate_default(uid, version=version) |
| 714 | + |
581 | 715 | def delete(self, uid: Union[UUID, str]): |
582 | 716 | """Predictors cannot be deleted at this time.""" |
583 | 717 | msg = "Predictors cannot be deleted. Use 'archive_version' or 'archive_root' instead." |
|
0 commit comments