@@ -38,6 +38,7 @@ async def _managed_semaphore_execution(
semaphore_key: str,
ttl: datetime.timedelta,
execution_context: str,
expected_lock_overall_time: datetime.timedelta,
) -> AsyncIterator:
"""Common semaphore management logic with auto-renewal."""
# Acquire the semaphore first
@@ -106,14 +107,14 @@ async def _periodic_renewer() -> None:
finally:
lock_release_time = arrow.utcnow()
locking_time = lock_release_time - lock_acquisition_time
if locking_time > DEFAULT_EXPECTED_LOCK_OVERALL_TIME:
if locking_time > expected_lock_overall_time:
_logger.warning(
"Semaphore '%s' was held for %s which is longer than expected (%s). "
"TIP: consider reducing the locking time by optimizing the code inside "
"the critical section or increasing the default locking time",
semaphore_key,
locking_time,
DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
expected_lock_overall_time,
)


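For orientation, the pattern the hunk above implements can be sketched in isolation: time the critical section and warn when it overruns the caller-supplied expectation. Below is a minimal sketch using plain datetime/logging (the real code measures with arrow inside the semaphore's finally block); the helper name is illustrative, not part of the library.

import contextlib
import datetime
import logging
from collections.abc import AsyncIterator

_logger = logging.getLogger(__name__)


@contextlib.asynccontextmanager
async def _timed_critical_section(
    semaphore_key: str,
    expected_lock_overall_time: datetime.timedelta,
) -> AsyncIterator[None]:
    # Illustrative only: measures how long the body runs and logs the same
    # kind of warning as the decorator when the expectation is exceeded.
    started_at = datetime.datetime.now(datetime.UTC)
    try:
        yield
    finally:
        locking_time = datetime.datetime.now(datetime.UTC) - started_at
        if locking_time > expected_lock_overall_time:
            _logger.warning(
                "Semaphore '%s' was held for %s, longer than the expected %s",
                semaphore_key,
                locking_time,
                expected_lock_overall_time,
            )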
@@ -157,6 +158,7 @@ def with_limited_concurrency(
ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
blocking: bool = True,
blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
expected_lock_overall_time: datetime.timedelta = DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
) -> Callable[
[Callable[P, Coroutine[Any, Any, R]]], Callable[P, Coroutine[Any, Any, R]]
]:
@@ -174,6 +176,7 @@
ttl: Time-to-live for semaphore entries (default: 5 minutes)
blocking: Whether to block when semaphore is full (default: True)
blocking_timeout: Maximum time to wait when blocking (default: socket timeout)
expected_lock_overall_time: expected upper bound on how long the semaphore is held; a warning is logged if it is exceeded (default: DEFAULT_EXPECTED_LOCK_OVERALL_TIME)

Example:
@with_limited_concurrency(
@@ -209,7 +212,11 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
)

async with _managed_semaphore_execution(
semaphore, semaphore_key, ttl, f"coroutine_{coro.__name__}"
semaphore,
semaphore_key,
ttl,
f"coroutine_{coro.__name__}",
expected_lock_overall_time,
):
return await coro(*args, **kwargs)

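A usage sketch of the coroutine decorator with the new parameter; argument order and names follow the signature above, while the Redis client instance and the decorator's import path are assumptions, not shown in this diff.

import datetime

# Assumed: `redis_client` is a connected RedisClientSDK and
# `with_limited_concurrency` is imported from the decorators module above.
@with_limited_concurrency(
    redis_client,
    key="cluster:scale-up",
    capacity=5,
    expected_lock_overall_time=datetime.timedelta(seconds=30),
)
async def scale_up_cluster() -> None:
    # Critical section: at most 5 concurrent executions across all workers;
    # a warning is logged if one execution holds its slot longer than 30s.
    ...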
@@ -226,6 +233,7 @@ def with_limited_concurrency_cm(
ttl: datetime.timedelta = DEFAULT_SEMAPHORE_TTL,
blocking: bool = True,
blocking_timeout: datetime.timedelta | None = DEFAULT_SOCKET_TIMEOUT,
expected_lock_overall_time: datetime.timedelta = DEFAULT_EXPECTED_LOCK_OVERALL_TIME,
) -> Callable[
[Callable[P, AbstractAsyncContextManager[R]]],
Callable[P, AbstractAsyncContextManager[R]],
@@ -244,6 +252,7 @@
ttl: Time-to-live for semaphore entries (default: 5 minutes)
blocking: Whether to block when semaphore is full (default: True)
blocking_timeout: Maximum time to wait when blocking (default: socket timeout)
expected_lock_overall_time: expected upper bound on how long the semaphore is held; a warning is logged if it is exceeded (default: DEFAULT_EXPECTED_LOCK_OVERALL_TIME)

Example:
@asynccontextmanager
@@ -281,7 +290,11 @@ async def _wrapper(*args: P.args, **kwargs: P.kwargs) -> AsyncIterator[R]:

async with (
_managed_semaphore_execution(
semaphore, semaphore_key, ttl, f"context_manager_{cm_func.__name__}"
semaphore,
semaphore_key,
ttl,
f"context_manager_{cm_func.__name__}",
expected_lock_overall_time,
),
cm_func(*args, **kwargs) as value,
):
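The context-manager variant reads the same way. A sketch, again assuming the Redis client and the import are in place, with the wrapped factory decorated by asynccontextmanager first, as in the tests further below.

import datetime
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager

# Assumed: `redis_client` and the import of `with_limited_concurrency_cm`.
@with_limited_concurrency_cm(
    redis_client,
    key="project:1234",
    capacity=2,
    expected_lock_overall_time=datetime.timedelta(minutes=1),
)
@asynccontextmanager
async def project_lock() -> AsyncIterator[None]:
    yield  # at most 2 holders cluster-wide; warns if held longer than 1 minute


async def do_work() -> None:
    async with project_lock():
        ...  # protected section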
(next file: tests for the limited-concurrency decorators)
@@ -13,6 +13,7 @@

import pytest
from pytest_mock import MockerFixture
from pytest_simcore.helpers.logging_tools import log_context
from servicelib.redis import RedisClientSDK
from servicelib.redis._constants import (
SEMAPHORE_HOLDER_KEY_PREFIX,
@@ -275,10 +276,10 @@ async def limited_function() -> None:
key=semaphore_name,
capacity=1,
blocking=False,
blocking_timeout=datetime.timedelta(seconds=0.1),
blocking_timeout=None,
)
async def limited_function_non_blocking() -> None:
await asyncio.sleep(0.5)
await asyncio.sleep(2)

tasks = [asyncio.create_task(limited_function_non_blocking()) for _ in range(3)]
results = await asyncio.gather(*tasks, return_exceptions=True)
@@ -365,11 +366,11 @@ async def test_with_large_capacity(
redis_client_sdk: RedisClientSDK,
semaphore_name: str,
):
large_capacity = 100
large_capacity = 20
concurrent_count = 0
max_concurrent = 0
sleep_time_s = 5
num_tasks = 1000
sleep_time_s = 10
num_tasks = 500

@with_limited_concurrency(
redis_client_sdk,
@@ -382,9 +383,8 @@ async def limited_function() -> None:
nonlocal concurrent_count, max_concurrent
concurrent_count += 1
max_concurrent = max(max_concurrent, concurrent_count)
logging.info("Started task, current concurrent: %d", concurrent_count)
await asyncio.sleep(sleep_time_s)
logging.info("Done task, current concurrent: %d", concurrent_count)
with log_context(logging.INFO, f"task with {concurrent_count=}"):
await asyncio.sleep(sleep_time_s)
concurrent_count -= 1

# Start tasks equal to the large capacity
@@ -400,6 +400,63 @@ async def limited_function() -> None:
assert max_concurrent <= large_capacity


async def test_long_locking_logs_warning(
redis_client_sdk: RedisClientSDK,
semaphore_name: str,
caplog: pytest.LogCaptureFixture,
mocker: MockerFixture,
):
@with_limited_concurrency(
redis_client_sdk,
key=semaphore_name,
capacity=1,
blocking=True,
blocking_timeout=None,
expected_lock_overall_time=datetime.timedelta(milliseconds=200),
)
async def limited_function() -> None:
with log_context(logging.INFO, "task"):
await asyncio.sleep(0.4)

with caplog.at_level(logging.WARNING):
await limited_function()
assert caplog.records
assert "longer than expected" in caplog.messages[-1]


@pytest.mark.skip
async def test_semaphore_fair_queuing(
redis_client_sdk: RedisClientSDK,
semaphore_name: str,
):
entered_order: list[int] = []

@with_limited_concurrency(
redis_client_sdk,
key=semaphore_name,
capacity=1,
)
async def limited_function(call_id: int):
entered_order.append(call_id)
await asyncio.sleep(0.1)
return call_id

# Launch tasks in a specific order
num_tasks = 10
tasks = []
for i in range(num_tasks):
tasks.append(asyncio.create_task(limited_function(i)))
await asyncio.sleep(0.01) # Small delay to help preserve order
results = await asyncio.gather(*tasks)

# All should complete successfully and in order
assert results == list(range(num_tasks))
# The order in which they entered the critical section should match the order of submission
assert entered_order == list(
range(num_tasks)
), f"Expected fair queuing, got {entered_order}"


async def test_context_manager_basic_functionality(
redis_client_sdk: RedisClientSDK,
semaphore_name: str,
@@ -442,6 +499,7 @@ async def test_context_manager_capacity_enforcement(
redis_client_sdk,
key=semaphore_name,
capacity=2,
blocking_timeout=None,
)
@asynccontextmanager
async def limited_context_manager():
(next file: dask task-resources helper)
@@ -88,7 +88,13 @@ async def list_unrunnable_tasks(self, app: FastAPI) -> list[DaskTask]:

def get_task_required_resources(self, task) -> Resources:
assert self # nosec
return utils.resources_from_dask_task(task)
task_required_resources = utils.resources_from_dask_task(task)
# ensure cpu is set at least to 1 as dask-workers use 1 thread per CPU
if task_required_resources.cpus < 1.0:
task_required_resources = task_required_resources.model_copy(
update={"cpus": 1.0}
)
return task_required_resources

async def get_task_defined_instance(
self, app: FastAPI, task
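The guard added to get_task_required_resources rounds fractional CPU requests up to a full CPU, since a dask worker dedicates one thread per CPU and a 0.1-CPU task would still occupy a whole thread. A standalone sketch of the same clamp, assuming Resources is a pydantic model with a float cpus field as the snippet suggests:

# Assumed: `Resources` is the pydantic model used above, imported from the
# service's models module.
def clamp_to_at_least_one_cpu(resources: Resources) -> Resources:
    # Dask workers run one thread per CPU, so requests below 1.0 CPU are
    # bumped up; everything else is returned unchanged.
    if resources.cpus < 1.0:
        return resources.model_copy(update={"cpus": 1.0})
    return resources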
(next file: dask client backend)
@@ -15,6 +15,7 @@
from typing import Final, cast

import distributed
import distributed.client
from aiohttp import ClientResponseError
from common_library.json_serialization import json_dumps
from common_library.logging.logging_errors import create_troubleshooting_log_kwargs
@@ -68,6 +69,7 @@

from ..core.errors import (
ComputationalBackendNoS3AccessError,
ComputationalBackendNotConnectedError,
ComputationalBackendTaskNotFoundError,
ComputationalBackendTaskResultsNotReadyError,
TaskSchedulingError,
@@ -92,7 +94,7 @@


_UserCallbackInSepThread = Callable[[], None]
_MAX_CONCURRENT_CLIENT_CONNECTIONS: Final[int] = 10
_MAX_CONCURRENT_CLIENT_CONNECTIONS: Final[int] = 1


@dataclass(frozen=True, kw_only=True, slots=True)
@@ -552,6 +554,11 @@ async def get_task_result(self, job_id: str) -> TaskOutputData:
raise ComputationalBackendTaskNotFoundError(job_id=job_id) from exc
except distributed.TimeoutError as exc:
raise ComputationalBackendTaskResultsNotReadyError(job_id=job_id) from exc
except (
distributed.client.FutureCancelledError,
distributed.client.FuturesCancelledError,
) as exc:
raise ComputationalBackendNotConnectedError from exc

async def release_task_result(self, job_id: str) -> None:
_logger.debug("releasing results for %s", f"{job_id=}")
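On the caller side, the new except clause means a cancelled dask future now surfaces as ComputationalBackendNotConnectedError (imported from the core errors module as shown above) instead of leaking a distributed exception. A hedged sketch of how calling code might react; the retry policy and the backend variable are illustrative only.

async def fetch_result_or_requeue(backend, job_id: str):
    # Sketch only: `backend` stands for the dask client wrapper shown above.
    try:
        return await backend.get_task_result(job_id)
    except ComputationalBackendNotConnectedError:
        # The future was cancelled because the scheduler connection was lost;
        # leave the job queued and let the next scheduling pass retry it.
        return None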