diff --git a/src/citrine/__version__.py b/src/citrine/__version__.py index 1d3669f91..c1d933105 100644 --- a/src/citrine/__version__.py +++ b/src/citrine/__version__.py @@ -1 +1 @@ -__version__ = "3.25.1" +__version__ = "3.25.2" diff --git a/src/citrine/informatics/executions/predictor_evaluation.py b/src/citrine/informatics/executions/predictor_evaluation.py index ae8270337..f74dcd113 100644 --- a/src/citrine/informatics/executions/predictor_evaluation.py +++ b/src/citrine/informatics/executions/predictor_evaluation.py @@ -2,14 +2,16 @@ from typing import List, Optional, Union from uuid import UUID -from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult -from citrine.informatics.predictor_evaluator import PredictorEvaluator -from citrine.resources.status_detail import StatusDetail +from citrine._rest.asynchronous_object import AsynchronousObject from citrine._rest.engine_resource import EngineResourceWithoutStatus from citrine._rest.resource import PredictorRef from citrine._serialization import properties from citrine._serialization.serializable import Serializable +from citrine._session import Session from citrine._utils.functions import format_escaped_url +from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult +from citrine.informatics.predictor_evaluator import PredictorEvaluator +from citrine.resources.status_detail import StatusDetail class PredictorEvaluatorsResponse(Serializable['EvaluatorsPayload']): @@ -36,7 +38,7 @@ def __init__(self, self.predictor = PredictorRef(predictor_id, predictor_version) -class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation']): +class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation'], AsynchronousObject): """The evaluation of a predictor's performance.""" uid: UUID = properties.UUID('id', serializable=False) @@ -56,6 +58,12 @@ class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation']): default=[], serializable=False) """:List[StatusDetail]: a list of structured status info, containing the message and level""" + project_id: Optional[UUID] = None + _session: Optional[Session] = None + _in_progress_statuses = ["INPROGRESS"] + _succeeded_statuses = ["SUCCEEDED"] + _failed_statuses = ["FAILED"] + def _path(self): return format_escaped_url( '/projects/{project_id}/predictor-evaluations/{evaluation_id}', diff --git a/tests/resources/test_predictor_evaluations.py b/tests/resources/test_predictor_evaluations.py index 73e6789d8..fed853601 100644 --- a/tests/resources/test_predictor_evaluations.py +++ b/tests/resources/test_predictor_evaluations.py @@ -1,3 +1,4 @@ +from copy import deepcopy import uuid import pytest @@ -5,6 +6,7 @@ from citrine.resources.predictor_evaluation import PredictorEvaluationCollection from citrine.informatics.executions.predictor_evaluation import PredictorEvaluationRequest from citrine.informatics.predictors import GraphPredictor +from citrine.jobs.waiting import wait_while_executing from tests.utils.factories import CrossValidationEvaluatorFactory, PredictorEvaluationDataFactory,\ PredictorEvaluationFactory, PredictorInstanceDataFactory, PredictorRefFactory @@ -247,3 +249,26 @@ def test_delete_not_implemented(): pec = PredictorEvaluationCollection(uuid.uuid4(), session) with pytest.raises(NotImplementedError): pec.delete(uuid.uuid4()) + + +def test_wait(): + in_progress_response = PredictorEvaluationFactory(metadata__status={"major": "INPROGRESS", "minor": "EXECUTING", "detail": []}) + completed_response = deepcopy(in_progress_response) + completed_response["metadata"]["status"]["major"] = "SUCCEEDED" + completed_response["metadata"]["status"]["minor"] = "COMPLETED" + + session = FakeSession() + pec = PredictorEvaluationCollection(uuid.uuid4(), session) + + # wait_while_executing makes two additional calls once it's done polling. + responses = 4 * [in_progress_response] + 3 * [completed_response] + session.set_responses(*responses) + + evaluation = pec.build(in_progress_response) + wait_while_executing(collection=pec, execution=evaluation, interval=0.1) + + expected_call = FakeCall( + method='GET', + path=f'/projects/{pec.project_id}/predictor-evaluations/{in_progress_response["id"]}' + ) + assert (len(responses) * [expected_call]) == session.calls