Skip to content

Commit f2484f8

Browse files
committed
PredictorEvaluation should support wait_while_executing.
1 parent 603dcdd commit f2484f8

File tree

3 files changed

+38
-5
lines changed

3 files changed

+38
-5
lines changed

src/citrine/__version__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = "3.25.0"
1+
__version__ = "3.25.2"

src/citrine/informatics/executions/predictor_evaluation.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,16 @@
22
from typing import List, Optional, Union
33
from uuid import UUID
44

5-
from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult
6-
from citrine.informatics.predictor_evaluator import PredictorEvaluator
7-
from citrine.resources.status_detail import StatusDetail
5+
from citrine._rest.asynchronous_object import AsynchronousObject
86
from citrine._rest.engine_resource import EngineResourceWithoutStatus
97
from citrine._rest.resource import PredictorRef
108
from citrine._serialization import properties
119
from citrine._serialization.serializable import Serializable
10+
from citrine._session import Session
1211
from citrine._utils.functions import format_escaped_url
12+
from citrine.informatics.predictor_evaluation_result import PredictorEvaluationResult
13+
from citrine.informatics.predictor_evaluator import PredictorEvaluator
14+
from citrine.resources.status_detail import StatusDetail
1315

1416

1517
class PredictorEvaluatorsResponse(Serializable['EvaluatorsPayload']):
@@ -36,7 +38,7 @@ def __init__(self,
3638
self.predictor = PredictorRef(predictor_id, predictor_version)
3739

3840

39-
class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation']):
41+
class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation'], AsynchronousObject):
4042
"""The evaluation of a predictor's performance."""
4143

4244
uid: UUID = properties.UUID('id', serializable=False)
@@ -56,6 +58,12 @@ class PredictorEvaluation(EngineResourceWithoutStatus['PredictorEvaluation']):
5658
default=[], serializable=False)
5759
""":List[StatusDetail]: a list of structured status info, containing the message and level"""
5860

61+
project_id: Optional[UUID] = None
62+
_session: Optional[Session] = None
63+
_in_progress_statuses = ["INPROGRESS"]
64+
_succeeded_statuses = ["SUCCEEDED"]
65+
_failed_statuses = ["FAILED"]
66+
5967
def _path(self):
6068
return format_escaped_url(
6169
'/projects/{project_id}/predictor-evaluations/{evaluation_id}',

tests/resources/test_predictor_evaluations.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
1+
from copy import deepcopy
12
import uuid
23

34
import pytest
45

56
from citrine.resources.predictor_evaluation import PredictorEvaluationCollection
67
from citrine.informatics.executions.predictor_evaluation import PredictorEvaluationRequest
78
from citrine.informatics.predictors import GraphPredictor
9+
from citrine.jobs.waiting import wait_while_executing
810

911
from tests.utils.factories import CrossValidationEvaluatorFactory, PredictorEvaluationDataFactory,\
1012
PredictorEvaluationFactory, PredictorInstanceDataFactory, PredictorRefFactory
@@ -247,3 +249,26 @@ def test_delete_not_implemented():
247249
pec = PredictorEvaluationCollection(uuid.uuid4(), session)
248250
with pytest.raises(NotImplementedError):
249251
pec.delete(uuid.uuid4())
252+
253+
254+
def test_wait():
255+
in_progress_response = PredictorEvaluationFactory(metadata__status={"major": "INPROGRESS", "minor": "EXECUTING", "detail": []})
256+
completed_response = deepcopy(in_progress_response)
257+
completed_response["metadata"]["status"]["major"] = "SUCCEEDED"
258+
completed_response["metadata"]["status"]["minor"] = "COMPLETED"
259+
260+
session = FakeSession()
261+
pec = PredictorEvaluationCollection(uuid.uuid4(), session)
262+
263+
# wait_while_executing makes two additional calls once it's done polling.
264+
responses = 4 * [in_progress_response] + 3 * [completed_response]
265+
session.set_responses(*responses)
266+
267+
evaluation = pec.build(in_progress_response)
268+
wait_while_executing(collection=pec, execution=evaluation, interval=0.1)
269+
270+
expected_call = FakeCall(
271+
method='GET',
272+
path=f'/projects/{pec.project_id}/predictor-evaluations/{in_progress_response["id"]}'
273+
)
274+
assert (len(responses) * [expected_call]) == session.calls

0 commit comments

Comments
 (0)