diff --git a/durabletask/aio/client.py b/durabletask/aio/client.py index 9b93b96..3c9ab58 100644 --- a/durabletask/aio/client.py +++ b/durabletask/aio/client.py @@ -3,8 +3,10 @@ import logging import uuid +from dataclasses import dataclass from datetime import datetime -from typing import Any, Optional, Sequence, Union +from enum import Enum +from typing import Any, Optional, Sequence, TypeVar, Union import grpc from google.protobuf import wrappers_pb2 @@ -16,13 +18,85 @@ from durabletask import task from durabletask.aio.internal.grpc_interceptor import DefaultClientInterceptorImpl from durabletask.aio.internal.shared import ClientInterceptor, get_grpc_aio_channel -from durabletask.client import ( - OrchestrationState, - OrchestrationStatus, - TInput, - TOutput, - new_orchestration_state, -) + +TInput = TypeVar("TInput") +TOutput = TypeVar("TOutput") + + +class OrchestrationStatus(Enum): + """The status of an orchestration instance.""" + + RUNNING = pb.ORCHESTRATION_STATUS_RUNNING + COMPLETED = pb.ORCHESTRATION_STATUS_COMPLETED + FAILED = pb.ORCHESTRATION_STATUS_FAILED + TERMINATED = pb.ORCHESTRATION_STATUS_TERMINATED + CONTINUED_AS_NEW = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW + PENDING = pb.ORCHESTRATION_STATUS_PENDING + SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED + + def __str__(self): + return helpers.get_orchestration_status_str(self.value) + + +@dataclass +class OrchestrationState: + instance_id: str + name: str + runtime_status: OrchestrationStatus + created_at: datetime + last_updated_at: datetime + serialized_input: Optional[str] + serialized_output: Optional[str] + serialized_custom_status: Optional[str] + failure_details: Optional[task.FailureDetails] + + def raise_if_failed(self): + if self.failure_details is not None: + raise OrchestrationFailedError( + f"Orchestration '{self.instance_id}' failed: {self.failure_details.message}", + self.failure_details, + ) + + +class OrchestrationFailedError(Exception): + def __init__(self, message: str, failure_details: task.FailureDetails): + super().__init__(message) + self._failure_details = failure_details + + @property + def failure_details(self): + return self._failure_details + + +def new_orchestration_state( + instance_id: str, res: pb.GetInstanceResponse +) -> Optional[OrchestrationState]: + if not res.exists: + return None + + state = res.orchestrationState + + failure_details = None + if state.failureDetails.errorMessage != "" or state.failureDetails.errorType != "": + failure_details = task.FailureDetails( + state.failureDetails.errorMessage, + state.failureDetails.errorType, + state.failureDetails.stackTrace.value + if not helpers.is_empty(state.failureDetails.stackTrace) + else None, + ) + + return OrchestrationState( + instance_id, + state.name, + OrchestrationStatus(state.orchestrationStatus), + state.createdTimestamp.ToDatetime(), + state.lastUpdatedTimestamp.ToDatetime(), + state.input.value if not helpers.is_empty(state.input) else None, + state.output.value if not helpers.is_empty(state.output) else None, + state.customStatus.value if not helpers.is_empty(state.customStatus) else None, + failure_details, + ) class AsyncTaskHubGrpcClient: diff --git a/durabletask/client.py b/durabletask/client.py index 1e28f30..7698a9a 100644 --- a/durabletask/client.py +++ b/durabletask/client.py @@ -1,101 +1,22 @@ -# Copyright (c) Microsoft Corporation. +# Copyright (c) The Dapr Authors. # Licensed under the MIT License. import logging -import uuid -from dataclasses import dataclass from datetime import datetime -from enum import Enum -from typing import Any, Optional, Sequence, TypeVar, Union +from functools import partial +from typing import Any, Optional, Sequence, Union -import grpc -from google.protobuf import wrappers_pb2 +from anyio.from_thread import start_blocking_portal # type: ignore -import durabletask.internal.helpers as helpers import durabletask.internal.orchestrator_service_pb2 as pb -import durabletask.internal.orchestrator_service_pb2_grpc as stubs import durabletask.internal.shared as shared from durabletask import task -from durabletask.internal.grpc_interceptor import DefaultClientInterceptorImpl - -TInput = TypeVar("TInput") -TOutput = TypeVar("TOutput") - - -class OrchestrationStatus(Enum): - """The status of an orchestration instance.""" - - RUNNING = pb.ORCHESTRATION_STATUS_RUNNING - COMPLETED = pb.ORCHESTRATION_STATUS_COMPLETED - FAILED = pb.ORCHESTRATION_STATUS_FAILED - TERMINATED = pb.ORCHESTRATION_STATUS_TERMINATED - CONTINUED_AS_NEW = pb.ORCHESTRATION_STATUS_CONTINUED_AS_NEW - PENDING = pb.ORCHESTRATION_STATUS_PENDING - SUSPENDED = pb.ORCHESTRATION_STATUS_SUSPENDED - - def __str__(self): - return helpers.get_orchestration_status_str(self.value) - - -@dataclass -class OrchestrationState: - instance_id: str - name: str - runtime_status: OrchestrationStatus - created_at: datetime - last_updated_at: datetime - serialized_input: Optional[str] - serialized_output: Optional[str] - serialized_custom_status: Optional[str] - failure_details: Optional[task.FailureDetails] - - def raise_if_failed(self): - if self.failure_details is not None: - raise OrchestrationFailedError( - f"Orchestration '{self.instance_id}' failed: {self.failure_details.message}", - self.failure_details, - ) - - -class OrchestrationFailedError(Exception): - def __init__(self, message: str, failure_details: task.FailureDetails): - super().__init__(message) - self._failure_details = failure_details - - @property - def failure_details(self): - return self._failure_details - - -def new_orchestration_state( - instance_id: str, res: pb.GetInstanceResponse -) -> Optional[OrchestrationState]: - if not res.exists: - return None - - state = res.orchestrationState - - failure_details = None - if state.failureDetails.errorMessage != "" or state.failureDetails.errorType != "": - failure_details = task.FailureDetails( - state.failureDetails.errorMessage, - state.failureDetails.errorType, - state.failureDetails.stackTrace.value - if not helpers.is_empty(state.failureDetails.stackTrace) - else None, - ) - - return OrchestrationState( - instance_id, - state.name, - OrchestrationStatus(state.orchestrationStatus), - state.createdTimestamp.ToDatetime(), - state.lastUpdatedTimestamp.ToDatetime(), - state.input.value if not helpers.is_empty(state.input) else None, - state.output.value if not helpers.is_empty(state.output) else None, - state.customStatus.value if not helpers.is_empty(state.customStatus) else None, - failure_details, - ) +from durabletask.aio.client import ( + AsyncTaskHubGrpcClient, + OrchestrationState, + TInput, + TOutput, +) class TaskHubGrpcClient: @@ -110,25 +31,58 @@ def __init__( interceptors: Optional[Sequence[shared.ClientInterceptor]] = None, channel_options: Optional[Sequence[tuple[str, Any]]] = None, ): - # If the caller provided metadata, we need to create a new interceptor for it and - # add it to the list of interceptors. - if interceptors is not None: - interceptors = list(interceptors) - if metadata is not None: - interceptors.append(DefaultClientInterceptorImpl(metadata)) - elif metadata is not None: - interceptors = [DefaultClientInterceptorImpl(metadata)] - else: - interceptors = None - - channel = shared.get_grpc_channel( - host_address=host_address, - secure_channel=secure_channel, - interceptors=interceptors, - options=channel_options, + # Store configuration to construct the async client on demand + self._host_address = host_address + self._metadata = metadata + self._log_handler = log_handler + self._log_formatter = log_formatter + self._secure_channel = secure_channel + self._interceptors = interceptors + self._channel_options = channel_options + + self._portal_cm = start_blocking_portal() + self._portal = self._portal_cm.__enter__() + self._async_client = self._portal.call(self._create_async_client) + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + return False + + def __del__(self): + # Best-effort cleanup in case the user didn't call close() or use context manager + try: + self.close() + except Exception: + pass + + async def _create_async_client(self) -> AsyncTaskHubGrpcClient: + return AsyncTaskHubGrpcClient( + host_address=self._host_address, + metadata=self._metadata, + log_handler=self._log_handler, + log_formatter=self._log_formatter, + secure_channel=self._secure_channel, + interceptors=self._interceptors, + channel_options=self._channel_options, ) - self._stub = stubs.TaskHubSidecarServiceStub(channel) - self._logger = shared.get_logger("client", log_handler, log_formatter) + + def close(self): + if self._async_client is not None: + try: + self._portal.call(self._async_client.aclose) + except Exception: + pass + self._async_client = None + if self._portal is not None: + try: + self._portal_cm.__exit__(None, None, None) + except Exception: + pass + self._portal = None + self._portal_cm = None def schedule_new_orchestration( self, @@ -139,122 +93,87 @@ def schedule_new_orchestration( start_at: Optional[datetime] = None, reuse_id_policy: Optional[pb.OrchestrationIdReusePolicy] = None, ) -> str: - name = orchestrator if isinstance(orchestrator, str) else task.get_name(orchestrator) - - input_pb = ( - wrappers_pb2.StringValue(value=shared.to_json(input)) if input is not None else None - ) - - req = pb.CreateInstanceRequest( - name=name, - instanceId=instance_id if instance_id else uuid.uuid4().hex, - input=input_pb, - scheduledStartTimestamp=helpers.new_timestamp(start_at) if start_at else None, - version=wrappers_pb2.StringValue(value=""), - orchestrationIdReusePolicy=reuse_id_policy, + return self._portal.call( # type: ignore + partial( + self._async_client.schedule_new_orchestration, # type: ignore + orchestrator, + input=input, + instance_id=instance_id, + start_at=start_at, + reuse_id_policy=reuse_id_policy, + ) ) - self._logger.info(f"Starting new '{name}' instance with ID = '{req.instanceId}'.") - res: pb.CreateInstanceResponse = self._stub.StartInstance(req) - return res.instanceId - def get_orchestration_state( self, instance_id: str, *, fetch_payloads: bool = True ) -> Optional[OrchestrationState]: - req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - res: pb.GetInstanceResponse = self._stub.GetInstance(req) - return new_orchestration_state(req.instanceId, res) + return self._portal.call( # type: ignore + partial( + self._async_client.get_orchestration_state, # type: ignore + instance_id, + fetch_payloads=fetch_payloads, + ) + ) def wait_for_orchestration_start( self, instance_id: str, *, fetch_payloads: bool = False, timeout: int = 0 ) -> Optional[OrchestrationState]: - req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - try: - grpc_timeout = None if timeout == 0 else timeout - self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to start." + return self._portal.call( # type: ignore + partial( + self._async_client.wait_for_orchestration_start, # type: ignore + instance_id, + fetch_payloads=fetch_payloads, + timeout=timeout, ) - res: pb.GetInstanceResponse = self._stub.WaitForInstanceStart(req, timeout=grpc_timeout) - return new_orchestration_state(req.instanceId, res) - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore - # Replace gRPC error with the built-in TimeoutError - raise TimeoutError("Timed-out waiting for the orchestration to start") - else: - raise + ) def wait_for_orchestration_completion( self, instance_id: str, *, fetch_payloads: bool = True, timeout: int = 0 ) -> Optional[OrchestrationState]: - req = pb.GetInstanceRequest(instanceId=instance_id, getInputsAndOutputs=fetch_payloads) - try: - grpc_timeout = None if timeout == 0 else timeout - self._logger.info( - f"Waiting {'indefinitely' if timeout == 0 else f'up to {timeout}s'} for instance '{instance_id}' to complete." + return self._portal.call( # type: ignore + partial( + self._async_client.wait_for_orchestration_completion, # type: ignore + instance_id, + fetch_payloads=fetch_payloads, + timeout=timeout, ) - res: pb.GetInstanceResponse = self._stub.WaitForInstanceCompletion( - req, timeout=grpc_timeout - ) - state = new_orchestration_state(req.instanceId, res) - if not state: - return None - - if ( - state.runtime_status == OrchestrationStatus.FAILED - and state.failure_details is not None - ): - details = state.failure_details - self._logger.info( - f"Instance '{instance_id}' failed: [{details.error_type}] {details.message}" - ) - elif state.runtime_status == OrchestrationStatus.TERMINATED: - self._logger.info(f"Instance '{instance_id}' was terminated.") - elif state.runtime_status == OrchestrationStatus.COMPLETED: - self._logger.info(f"Instance '{instance_id}' completed.") - - return state - except grpc.RpcError as rpc_error: - if rpc_error.code() == grpc.StatusCode.DEADLINE_EXCEEDED: # type: ignore - # Replace gRPC error with the built-in TimeoutError - raise TimeoutError("Timed-out waiting for the orchestration to complete") - else: - raise + ) def raise_orchestration_event( self, instance_id: str, event_name: str, *, data: Optional[Any] = None ): - req = pb.RaiseEventRequest( - instanceId=instance_id, - name=event_name, - input=wrappers_pb2.StringValue(value=shared.to_json(data)) if data else None, + self._portal.call( # type: ignore + partial( + self._async_client.raise_orchestration_event, # type: ignore + instance_id, + event_name, + data=data, + ) ) - self._logger.info(f"Raising event '{event_name}' for instance '{instance_id}'.") - self._stub.RaiseEvent(req) - def terminate_orchestration( self, instance_id: str, *, output: Optional[Any] = None, recursive: bool = True ): - req = pb.TerminateRequest( - instanceId=instance_id, - output=wrappers_pb2.StringValue(value=shared.to_json(output)) if output else None, - recursive=recursive, + self._portal.call( # type: ignore + partial( + self._async_client.terminate_orchestration, # type: ignore + instance_id, + output=output, + recursive=recursive, + ) ) - self._logger.info(f"Terminating instance '{instance_id}'.") - self._stub.TerminateInstance(req) - def suspend_orchestration(self, instance_id: str): - req = pb.SuspendRequest(instanceId=instance_id) - self._logger.info(f"Suspending instance '{instance_id}'.") - self._stub.SuspendInstance(req) + self._portal.call(self._async_client.suspend_orchestration, instance_id) # type: ignore def resume_orchestration(self, instance_id: str): - req = pb.ResumeRequest(instanceId=instance_id) - self._logger.info(f"Resuming instance '{instance_id}'.") - self._stub.ResumeInstance(req) + self._portal.call(self._async_client.resume_orchestration, instance_id) # type: ignore def purge_orchestration(self, instance_id: str, recursive: bool = True): - req = pb.PurgeInstancesRequest(instanceId=instance_id, recursive=recursive) - self._logger.info(f"Purging instance '{instance_id}'.") - self._stub.PurgeInstances(req) + self._portal.call( # type: ignore + partial( + self._async_client.purge_orchestration, # type: ignore + instance_id, + recursive=recursive, + ) + ) diff --git a/pyproject.toml b/pyproject.toml index 6626bc2..3cc8d94 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,8 @@ readme = "README.md" dependencies = [ "grpcio", "protobuf>=6.31.1,<7.0.0", # follows grpcio generation version https://github.com/grpc/grpc/blob/v1.75.1/tools/distrib/python/grpcio_tools/setup.py - "asyncio" + "asyncio", + "anyio>=4.0.0,<5" ] [project.urls] diff --git a/tests/durabletask/test_orchestration_e2e.py b/tests/durabletask/test_orchestration_e2e.py index 225456d..fb763c2 100644 --- a/tests/durabletask/test_orchestration_e2e.py +++ b/tests/durabletask/test_orchestration_e2e.py @@ -9,6 +9,7 @@ import pytest from durabletask import client, task, worker +from durabletask.aio.client import OrchestrationStatus # NOTE: These tests assume a sidecar process is running. Example command: # dapr init || true @@ -42,7 +43,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.name == task.get_name(empty_orchestrator) assert state.instance_id == id assert state.failure_details is None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.serialized_input is None assert state.serialized_output is None assert state.serialized_custom_status is None @@ -73,7 +74,7 @@ def sequence(ctx: task.OrchestrationContext, start_val: int): assert state is not None assert state.name == task.get_name(sequence) assert state.instance_id == id - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.failure_details is None assert state.serialized_input == json.dumps(1) assert state.serialized_output == json.dumps([1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]) @@ -117,7 +118,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): assert state is not None assert state.name == task.get_name(orchestrator) assert state.instance_id == id - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.serialized_output == json.dumps("Kah-BOOOOM!!!") assert state.failure_details is None assert state.serialized_custom_status is None @@ -157,7 +158,7 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.failure_details is None assert activity_counter == 30 @@ -183,7 +184,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.serialized_output == json.dumps(["a", "b", "c"]) @@ -211,7 +212,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED if raise_event: assert state.serialized_output == json.dumps("approved") else: @@ -235,11 +236,11 @@ def orchestrator(ctx: task.OrchestrationContext, _): # Suspend the orchestration and wait for it to go into the SUSPENDED state task_hub_client.suspend_orchestration(id) - while state.runtime_status == client.OrchestrationStatus.RUNNING: + while state.runtime_status == OrchestrationStatus.RUNNING: time.sleep(0.1) state = task_hub_client.get_orchestration_state(id) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.SUSPENDED + assert state.runtime_status == OrchestrationStatus.SUSPENDED # Raise an event to the orchestration and confirm that it does NOT complete task_hub_client.raise_orchestration_event(id, "my_event", data=42) @@ -253,7 +254,7 @@ def orchestrator(ctx: task.OrchestrationContext, _): task_hub_client.resume_orchestration(id) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.serialized_output == json.dumps(42) @@ -271,12 +272,12 @@ def orchestrator(ctx: task.OrchestrationContext, _): id = task_hub_client.schedule_new_orchestration(orchestrator) state = task_hub_client.wait_for_orchestration_start(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.RUNNING + assert state.runtime_status == OrchestrationStatus.RUNNING task_hub_client.terminate_orchestration(id, output="some reason for termination") state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.TERMINATED + assert state.runtime_status == OrchestrationStatus.TERMINATED assert state.serialized_output == json.dumps("some reason for termination") @@ -320,7 +321,7 @@ def parent_orchestrator(ctx: task.OrchestrationContext, count: int): metadata = task_hub_client.wait_for_orchestration_completion(instance_id, timeout=30) assert metadata is not None - assert metadata.runtime_status == client.OrchestrationStatus.TERMINATED + assert metadata.runtime_status == OrchestrationStatus.TERMINATED assert metadata.serialized_output == f'"{output}"' time.sleep(delay_time) @@ -365,7 +366,7 @@ def orchestrator(ctx: task.OrchestrationContext, input: int): state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.serialized_output == json.dumps(all_results) assert state.serialized_input == json.dumps(4) assert all_results == [1, 2, 3, 4, 5] @@ -401,7 +402,7 @@ def orchestrator(ctx: task.OrchestrationContext, counter: int): state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED output = json.loads(state.serialized_output) # Should have called activity 3 times with input values 1, 2, 3 @@ -459,7 +460,7 @@ def throw_activity_with_retry(ctx: task.ActivityContext, _): id = task_hub_client.schedule_new_orchestration(parent_orchestrator_with_retry) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.runtime_status == OrchestrationStatus.FAILED assert state.failure_details is not None assert state.failure_details.error_type == "TaskFailedError" assert state.failure_details.message.startswith("Sub-orchestration task #1 failed:") @@ -500,7 +501,7 @@ def throw_activity(ctx: task.ActivityContext, _): id = task_hub_client.schedule_new_orchestration(mock_orchestrator) state = task_hub_client.wait_for_orchestration_completion(id, timeout=30) assert state is not None - assert state.runtime_status == client.OrchestrationStatus.FAILED + assert state.runtime_status == OrchestrationStatus.FAILED assert state.failure_details is not None assert state.failure_details.error_type == "TaskFailedError" assert state.failure_details.message.endswith("Activity task #1 failed: Kah-BOOOOM!!!") @@ -525,7 +526,7 @@ def empty_orchestrator(ctx: task.OrchestrationContext, _): assert state.name == task.get_name(empty_orchestrator) assert state.instance_id == id assert state.failure_details is None - assert state.runtime_status == client.OrchestrationStatus.COMPLETED + assert state.runtime_status == OrchestrationStatus.COMPLETED assert state.serialized_input is None assert state.serialized_output is None assert state.serialized_custom_status == '"foobaz"' diff --git a/tests/durabletask/test_orchestration_e2e_async.py b/tests/durabletask/test_orchestration_e2e_async.py index c441bdc..602fe5b 100644 --- a/tests/durabletask/test_orchestration_e2e_async.py +++ b/tests/durabletask/test_orchestration_e2e_async.py @@ -9,8 +9,7 @@ import pytest from durabletask import task, worker -from durabletask.aio.client import AsyncTaskHubGrpcClient -from durabletask.client import OrchestrationStatus +from durabletask.aio.client import AsyncTaskHubGrpcClient, OrchestrationStatus # NOTE: These tests assume a sidecar process is running. Example command: # go install github.com/dapr/durabletask-go@main diff --git a/tests/durabletask/test_orchestration_wait.py b/tests/durabletask/test_orchestration_wait.py index 49eab0e..065fca8 100644 --- a/tests/durabletask/test_orchestration_wait.py +++ b/tests/durabletask/test_orchestration_wait.py @@ -1,4 +1,4 @@ -from unittest.mock import Mock +from unittest.mock import AsyncMock import pytest @@ -22,15 +22,15 @@ def test_wait_for_orchestration_start_timeout(timeout): response.orchestrationState.CopyFrom(state) c = TaskHubGrpcClient() - c._stub = Mock() - c._stub.WaitForInstanceStart.return_value = response + c._async_client._stub = AsyncMock() + c._async_client._stub.WaitForInstanceStart.return_value = response grpc_timeout = None if timeout is None else timeout c.wait_for_orchestration_start(instance_id, timeout=grpc_timeout) # Verify WaitForInstanceStart was called with timeout=None - c._stub.WaitForInstanceStart.assert_called_once() - _, kwargs = c._stub.WaitForInstanceStart.call_args + c._async_client._stub.WaitForInstanceStart.assert_called_once() + _, kwargs = c._async_client._stub.WaitForInstanceStart.call_args if timeout is None or timeout == 0: assert kwargs.get("timeout") is None else: @@ -54,15 +54,15 @@ def test_wait_for_orchestration_completion_timeout(timeout): response.orchestrationState.CopyFrom(state) c = TaskHubGrpcClient() - c._stub = Mock() - c._stub.WaitForInstanceCompletion.return_value = response + c._async_client._stub = AsyncMock() + c._async_client._stub.WaitForInstanceCompletion.return_value = response grpc_timeout = None if timeout is None else timeout c.wait_for_orchestration_completion(instance_id, timeout=grpc_timeout) # Verify WaitForInstanceStart was called with timeout=None - c._stub.WaitForInstanceCompletion.assert_called_once() - _, kwargs = c._stub.WaitForInstanceCompletion.call_args + c._async_client._stub.WaitForInstanceCompletion.assert_called_once() + _, kwargs = c._async_client._stub.WaitForInstanceCompletion.call_args if timeout is None or timeout == 0: assert kwargs.get("timeout") is None else: