Skip to content
Open
4 changes: 2 additions & 2 deletions dapr/aio/clients/grpc/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
UnlockResponseStatus,
)
from dapr.clients.grpc._state import StateItem, StateOptions
from dapr.clients.health import DaprHealth
from dapr.aio.clients.health import DaprHealth
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for my previous comment, but since we're calling the health check from the __init__, we can't make it asynchronously because __init__ doesn't support async. So I'm afraid we need to stick with the synchronous health check here unfortunately.
It'd be good to have some sort of .connect method that we can call outside of the __init__ and would support asynchronous calls, but there's no such a thing right now, and I don't think we want to make a big change for this...
So, can you bring back the sync call please?

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it—thanks for the clarification! I've reverted to the synchronous health check.

from dapr.clients.retry import RetryPolicy
from dapr.common.pubsub.subscription import StreamInactiveError
from dapr.conf import settings
Expand Down Expand Up @@ -1906,7 +1906,7 @@ async def wait(self, timeout_s: float):
remaining = (start + timeout_s) - time.time()
if remaining < 0:
raise e
asyncio.sleep(min(1, remaining))
await asyncio.sleep(min(1, remaining))

async def get_metadata(self) -> GetMetadataResponse:
"""Returns information about the sidecar allowing for runtime
Expand Down
6 changes: 3 additions & 3 deletions dapr/aio/clients/grpc/subscription.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from grpc.aio import AioRpcError

from dapr.clients.grpc._response import TopicEventResponse
from dapr.clients.health import DaprHealth
from dapr.aio.clients.health import DaprHealth
from dapr.common.pubsub.subscription import (
StreamCancelledError,
StreamInactiveError,
Expand Down Expand Up @@ -52,7 +52,7 @@ async def outgoing_request_iterator():

async def reconnect_stream(self):
await self.close()
DaprHealth.wait_for_sidecar()
await DaprHealth.wait_for_sidecar()
print('Attempting to reconnect...')
await self.start()

Expand All @@ -67,7 +67,7 @@ async def next_message(self):
return None
return SubscriptionMessage(message.event_message)
except AioRpcError as e:
if e.code() == StatusCode.UNAVAILABLE:
if e.code() == StatusCode.UNAVAILABLE or e.code() == StatusCode.UNKNOWN:
print(
f'gRPC error while reading from stream: {e.details()}, '
f'Status Code: {e.code()}. '
Expand Down
59 changes: 59 additions & 0 deletions dapr/aio/clients/health.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-

"""
Copyright 2024 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import time

import aiohttp

from dapr.clients.http.conf import DAPR_API_TOKEN_HEADER, DAPR_USER_AGENT, USER_AGENT_HEADER
from dapr.clients.http.helpers import get_api_url
from dapr.conf import settings


class DaprHealth:
@staticmethod
async def wait_for_sidecar():
health_url = f'{get_api_url()}/healthz/outbound'
headers = {USER_AGENT_HEADER: DAPR_USER_AGENT}
if settings.DAPR_API_TOKEN is not None:
headers[DAPR_API_TOKEN_HEADER] = settings.DAPR_API_TOKEN
timeout = float(settings.DAPR_HEALTH_TIMEOUT)

start = time.time()
ssl_context = DaprHealth.get_ssl_context()

connector = aiohttp.TCPConnector(ssl=ssl_context)
async with aiohttp.ClientSession(connector=connector) as session:
while True:
try:
async with session.get(health_url, headers=headers) as response:
if 200 <= response.status < 300:
break
except aiohttp.ClientError as e:
print(f'Health check on {health_url} failed: {e}')
except Exception as e:
print(f'Unexpected error during health check: {e}')

remaining = (start + timeout) - time.time()
if remaining <= 0:
raise TimeoutError(f'Dapr health check timed out, after {timeout}.')
await asyncio.sleep(min(1, remaining))

@staticmethod
def get_ssl_context():
# This method is used (overwritten) from tests
# to return context for self-signed certificates
return None
4 changes: 2 additions & 2 deletions tests/clients/test_dapr_grpc_client_async_secure.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
from unittest.mock import patch

from dapr.aio.clients.grpc.client import DaprGrpcClientAsync
from dapr.clients.health import DaprHealth
from dapr.aio.clients.health import DaprHealth as DaprHealthAsync
from dapr.conf import settings
from tests.clients.certs import replacement_get_credentials_func, replacement_get_health_context
from tests.clients.test_dapr_grpc_client_async import DaprGrpcClientAsyncTests

from .fake_dapr_server import FakeDaprSidecar

DaprGrpcClientAsync.get_credentials = replacement_get_credentials_func
DaprHealth.get_ssl_context = replacement_get_health_context
DaprHealthAsync.get_ssl_context = replacement_get_health_context


class DaprSecureGrpcClientAsyncTests(DaprGrpcClientAsyncTests):
Expand Down
File renamed without changes.
197 changes: 197 additions & 0 deletions tests/clients/test_healthcheck_async.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
# -*- coding: utf-8 -*-

"""
Copyright 2025 The Dapr Authors
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import asyncio
import time
import unittest
from unittest.mock import AsyncMock, MagicMock, patch

from dapr.aio.clients.health import DaprHealth
from dapr.conf import settings
from dapr.version import __version__


class DaprHealthCheckAsyncTests(unittest.IsolatedAsyncioTestCase):
@patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500')
@patch('aiohttp.ClientSession.get')
async def test_wait_for_sidecar_success(self, mock_get):
# Create mock response
mock_response = MagicMock()
mock_response.status = 200
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
mock_get.return_value = mock_response

try:
await DaprHealth.wait_for_sidecar()
except Exception as e:
self.fail(f'wait_for_sidecar() raised an exception unexpectedly: {e}')

mock_get.assert_called_once()

# Check URL
called_url = mock_get.call_args[0][0]
self.assertEqual(called_url, 'http://domain.com:3500/v1.0/healthz/outbound')

# Check headers are properly set
headers = mock_get.call_args[1]['headers']
self.assertIn('User-Agent', headers)
self.assertEqual(headers['User-Agent'], f'dapr-sdk-python/{__version__}')

@patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500')
@patch.object(settings, 'DAPR_API_TOKEN', 'mytoken')
@patch('aiohttp.ClientSession.get')
async def test_wait_for_sidecar_success_with_api_token(self, mock_get):
# Create mock response
mock_response = MagicMock()
mock_response.status = 200
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
mock_get.return_value = mock_response

try:
await DaprHealth.wait_for_sidecar()
except Exception as e:
self.fail(f'wait_for_sidecar() raised an exception unexpectedly: {e}')

mock_get.assert_called_once()

# Check headers are properly set
headers = mock_get.call_args[1]['headers']
self.assertIn('User-Agent', headers)
self.assertEqual(headers['User-Agent'], f'dapr-sdk-python/{__version__}')
self.assertIn('dapr-api-token', headers)
self.assertEqual(headers['dapr-api-token'], 'mytoken')

@patch.object(settings, 'DAPR_HEALTH_TIMEOUT', '2.5')
@patch('aiohttp.ClientSession.get')
async def test_wait_for_sidecar_timeout(self, mock_get):
# Create mock response that always returns 500
mock_response = MagicMock()
mock_response.status = 500
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
mock_get.return_value = mock_response

start = time.time()

with self.assertRaises(TimeoutError):
await DaprHealth.wait_for_sidecar()

self.assertGreaterEqual(time.time() - start, 2.5)
self.assertGreater(mock_get.call_count, 1)

@patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500')
@patch.object(settings, 'DAPR_HEALTH_TIMEOUT', '5.0')
@patch('aiohttp.ClientSession.get')
async def test_health_check_does_not_block(self, mock_get):
"""Test that health check doesn't block other async tasks from running"""
# Mock health check to retry several times before succeeding
call_count = [0] # Use list to allow modification in nested function

def side_effect(*args, **kwargs):
call_count[0] += 1
# First 2 calls fail with ClientError, then succeed
# This will cause ~2 seconds of retries (1 second sleep after each failure)
if call_count[0] <= 2:
import aiohttp

raise aiohttp.ClientError('Connection refused')
else:
mock_response = MagicMock()
mock_response.status = 200
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
return mock_response

mock_get.side_effect = side_effect

# Counter that will be incremented by background task
counter = [0] # Use list to allow modification in nested function
is_running = [True]

async def increment_counter():
"""Background task that increments counter every 0.5 seconds"""
while is_running[0]:
await asyncio.sleep(0.5)
counter[0] += 1

# Start the background task
counter_task = asyncio.create_task(increment_counter())

try:
# Run health check (will take ~2 seconds with retries)
await DaprHealth.wait_for_sidecar()

# Stop the background task
is_running[0] = False
await asyncio.sleep(0.1) # Give it time to finish current iteration

# Verify the counter was incremented during health check
# In 2 seconds with 0.5s intervals, we expect at least 3 increments
self.assertGreaterEqual(
counter[0],
3,
f'Expected counter to increment at least 3 times during health check, '
f'but got {counter[0]}. This indicates health check may be blocking.',
)

# Verify health check made multiple attempts
self.assertGreaterEqual(call_count[0], 2)

finally:
# Clean up
is_running[0] = False
counter_task.cancel()
try:
await counter_task
except asyncio.CancelledError:
pass

@patch.object(settings, 'DAPR_HTTP_ENDPOINT', 'http://domain.com:3500')
@patch('aiohttp.ClientSession.get')
async def test_multiple_health_checks_concurrent(self, mock_get):
"""Test that multiple health check calls can run concurrently"""
# Create mock response
mock_response = MagicMock()
mock_response.status = 200
mock_response.__aenter__ = AsyncMock(return_value=mock_response)
mock_response.__aexit__ = AsyncMock(return_value=None)
mock_get.return_value = mock_response

# Run multiple health checks concurrently
start_time = time.time()
results = await asyncio.gather(
DaprHealth.wait_for_sidecar(),
DaprHealth.wait_for_sidecar(),
DaprHealth.wait_for_sidecar(),
)
elapsed = time.time() - start_time

# All should complete successfully
self.assertEqual(len(results), 3)
self.assertIsNone(results[0])
self.assertIsNone(results[1])
self.assertIsNone(results[2])

# Should complete quickly since they run concurrently
self.assertLess(elapsed, 1.0)

# Verify multiple calls were made
self.assertGreaterEqual(mock_get.call_count, 3)


if __name__ == '__main__':
unittest.main()