diff --git a/dapr/aio/clients/grpc/client.py b/dapr/aio/clients/grpc/client.py index 028eaef5..d363775f 100644 --- a/dapr/aio/clients/grpc/client.py +++ b/dapr/aio/clients/grpc/client.py @@ -1906,7 +1906,7 @@ async def wait(self, timeout_s: float): remaining = (start + timeout_s) - time.time() if remaining < 0: raise e - asyncio.sleep(min(1, remaining)) + await asyncio.sleep(min(1, remaining)) async def get_metadata(self) -> GetMetadataResponse: """Returns information about the sidecar allowing for runtime diff --git a/dapr/aio/clients/grpc/subscription.py b/dapr/aio/clients/grpc/subscription.py index fff74f16..32c544a2 100644 --- a/dapr/aio/clients/grpc/subscription.py +++ b/dapr/aio/clients/grpc/subscription.py @@ -4,7 +4,7 @@ from grpc.aio import AioRpcError from dapr.clients.grpc._response import TopicEventResponse -from dapr.clients.health import DaprHealth +from dapr.aio.clients.health import DaprHealth from dapr.common.pubsub.subscription import ( StreamCancelledError, StreamInactiveError, @@ -52,7 +52,7 @@ async def outgoing_request_iterator(): async def reconnect_stream(self): await self.close() - DaprHealth.wait_for_sidecar() + await DaprHealth.wait_for_sidecar() print('Attempting to reconnect...') await self.start() @@ -67,7 +67,7 @@ async def next_message(self): return None return SubscriptionMessage(message.event_message) except AioRpcError as e: - if e.code() == StatusCode.UNAVAILABLE: + if e.code() == StatusCode.UNAVAILABLE or e.code() == StatusCode.UNKNOWN: print( f'gRPC error while reading from stream: {e.details()}, ' f'Status Code: {e.code()}. ' diff --git a/dapr/aio/clients/health.py b/dapr/aio/clients/health.py new file mode 100644 index 00000000..9ab66ebb --- /dev/null +++ b/dapr/aio/clients/health.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2024 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import time + +import aiohttp + +from dapr.clients.http.conf import DAPR_API_TOKEN_HEADER, DAPR_USER_AGENT, USER_AGENT_HEADER +from dapr.clients.http.helpers import get_api_url +from dapr.conf import settings + + +class DaprHealth: + @staticmethod + async def wait_for_sidecar(): + health_url = f'{get_api_url()}/healthz/outbound' + headers = {USER_AGENT_HEADER: DAPR_USER_AGENT} + if settings.DAPR_API_TOKEN is not None: + headers[DAPR_API_TOKEN_HEADER] = settings.DAPR_API_TOKEN + timeout = float(settings.DAPR_HEALTH_TIMEOUT) + + start = time.time() + ssl_context = DaprHealth.get_ssl_context() + + connector = aiohttp.TCPConnector(ssl=ssl_context) + async with aiohttp.ClientSession(connector=connector) as session: + while True: + try: + async with session.get(health_url, headers=headers) as response: + if 200 <= response.status < 300: + break + except aiohttp.ClientError as e: + print(f'Health check on {health_url} failed: {e}') + except Exception as e: + print(f'Unexpected error during health check: {e}') + + remaining = (start + timeout) - time.time() + if remaining <= 0: + raise TimeoutError(f'Dapr health check timed out, after {timeout}.') + await asyncio.sleep(min(1, remaining)) + + @staticmethod + def get_ssl_context(): + # This method is used (overwritten) from tests + # to return context for self-signed certificates + return None diff --git a/tests/clients/test_dapr_grpc_client_async_secure.py b/tests/clients/test_dapr_grpc_client_async_secure.py index a49fe5fc..1d685287 100644 --- a/tests/clients/test_dapr_grpc_client_async_secure.py +++ b/tests/clients/test_dapr_grpc_client_async_secure.py @@ -17,14 +17,15 @@ from unittest.mock import patch from dapr.aio.clients.grpc.client import DaprGrpcClientAsync +from dapr.aio.clients.health import DaprHealth as DaprHealthAsync from dapr.clients.health import DaprHealth from dapr.conf import settings from tests.clients.certs import replacement_get_credentials_func, replacement_get_health_context from tests.clients.test_dapr_grpc_client_async import DaprGrpcClientAsyncTests - from .fake_dapr_server import FakeDaprSidecar DaprGrpcClientAsync.get_credentials = replacement_get_credentials_func +DaprHealthAsync.get_ssl_context = replacement_get_health_context DaprHealth.get_ssl_context = replacement_get_health_context diff --git a/tests/clients/test_heatlhcheck.py b/tests/clients/test_healthcheck.py similarity index 100% rename from tests/clients/test_heatlhcheck.py rename to tests/clients/test_healthcheck.py diff --git a/tests/clients/test_healthcheck_async.py b/tests/clients/test_healthcheck_async.py new file mode 100644 index 00000000..66876873 --- /dev/null +++ b/tests/clients/test_healthcheck_async.py @@ -0,0 +1,197 @@ +# -*- coding: utf-8 -*- + +""" +Copyright 2025 The Dapr Authors +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +""" + +import asyncio +import time +import unittest +from unittest.mock import AsyncMock, MagicMock, patch + +from dapr.aio.clients.health import DaprHealth +from dapr.conf import settings +from dapr.version import __version__ + + +class DaprHealthCheckAsyncTests(unittest.IsolatedAsyncioTestCase): + @patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500') + @patch('aiohttp.ClientSession.get') + async def test_wait_for_sidecar_success(self, mock_get): + # Create mock response + mock_response = MagicMock() + mock_response.status = 200 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_get.return_value = mock_response + + try: + await DaprHealth.wait_for_sidecar() + except Exception as e: + self.fail(f'wait_for_sidecar() raised an exception unexpectedly: {e}') + + mock_get.assert_called_once() + + # Check URL + called_url = mock_get.call_args[0][0] + self.assertEqual(called_url, 'http://domain.com:3500/v1.0/healthz/outbound') + + # Check headers are properly set + headers = mock_get.call_args[1]['headers'] + self.assertIn('User-Agent', headers) + self.assertEqual(headers['User-Agent'], f'dapr-sdk-python/{__version__}') + + @patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500') + @patch.object(settings, 'DAPR_API_TOKEN', 'mytoken') + @patch('aiohttp.ClientSession.get') + async def test_wait_for_sidecar_success_with_api_token(self, mock_get): + # Create mock response + mock_response = MagicMock() + mock_response.status = 200 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_get.return_value = mock_response + + try: + await DaprHealth.wait_for_sidecar() + except Exception as e: + self.fail(f'wait_for_sidecar() raised an exception unexpectedly: {e}') + + mock_get.assert_called_once() + + # Check headers are properly set + headers = mock_get.call_args[1]['headers'] + self.assertIn('User-Agent', headers) + self.assertEqual(headers['User-Agent'], f'dapr-sdk-python/{__version__}') + self.assertIn('dapr-api-token', headers) + self.assertEqual(headers['dapr-api-token'], 'mytoken') + + @patch.object(settings, 'DAPR_HEALTH_TIMEOUT', '2.5') + @patch('aiohttp.ClientSession.get') + async def test_wait_for_sidecar_timeout(self, mock_get): + # Create mock response that always returns 500 + mock_response = MagicMock() + mock_response.status = 500 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_get.return_value = mock_response + + start = time.time() + + with self.assertRaises(TimeoutError): + await DaprHealth.wait_for_sidecar() + + self.assertGreaterEqual(time.time() - start, 2.5) + self.assertGreater(mock_get.call_count, 1) + + @patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500') + @patch.object(settings, 'DAPR_HEALTH_TIMEOUT', '5.0') + @patch('aiohttp.ClientSession.get') + async def test_health_check_does_not_block(self, mock_get): + """Test that health check doesn't block other async tasks from running""" + # Mock health check to retry several times before succeeding + call_count = [0] # Use list to allow modification in nested function + + def side_effect(*args, **kwargs): + call_count[0] += 1 + # First 2 calls fail with ClientError, then succeed + # This will cause ~2 seconds of retries (1 second sleep after each failure) + if call_count[0] <= 2: + import aiohttp + + raise aiohttp.ClientError('Connection refused') + else: + mock_response = MagicMock() + mock_response.status = 200 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + return mock_response + + mock_get.side_effect = side_effect + + # Counter that will be incremented by background task + counter = [0] # Use list to allow modification in nested function + is_running = [True] + + async def increment_counter(): + """Background task that increments counter every 0.5 seconds""" + while is_running[0]: + await asyncio.sleep(0.5) + counter[0] += 1 + + # Start the background task + counter_task = asyncio.create_task(increment_counter()) + + try: + # Run health check (will take ~2 seconds with retries) + await DaprHealth.wait_for_sidecar() + + # Stop the background task + is_running[0] = False + await asyncio.sleep(0.1) # Give it time to finish current iteration + + # Verify the counter was incremented during health check + # In 2 seconds with 0.5s intervals, we expect at least 3 increments + self.assertGreaterEqual( + counter[0], + 3, + f'Expected counter to increment at least 3 times during health check, ' + f'but got {counter[0]}. This indicates health check may be blocking.', + ) + + # Verify health check made multiple attempts + self.assertGreaterEqual(call_count[0], 2) + + finally: + # Clean up + is_running[0] = False + counter_task.cancel() + try: + await counter_task + except asyncio.CancelledError: + pass + + @patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500') + @patch('aiohttp.ClientSession.get') + async def test_multiple_health_checks_concurrent(self, mock_get): + """Test that multiple health check calls can run concurrently""" + # Create mock response + mock_response = MagicMock() + mock_response.status = 200 + mock_response.__aenter__ = AsyncMock(return_value=mock_response) + mock_response.__aexit__ = AsyncMock(return_value=None) + mock_get.return_value = mock_response + + # Run multiple health checks concurrently + start_time = time.time() + results = await asyncio.gather( + DaprHealth.wait_for_sidecar(), + DaprHealth.wait_for_sidecar(), + DaprHealth.wait_for_sidecar(), + ) + elapsed = time.time() - start_time + + # All should complete successfully + self.assertEqual(len(results), 3) + self.assertIsNone(results[0]) + self.assertIsNone(results[1]) + self.assertIsNone(results[2]) + + # Should complete quickly since they run concurrently + self.assertLess(elapsed, 1.0) + + # Verify multiple calls were made + self.assertGreaterEqual(mock_get.call_count, 3) + + +if __name__ == '__main__': + unittest.main()