Skip to content

Commit b84b85f

Browse files
authored
🎨 Maintenance: Add new concurrency tooling (#5997)
1 parent 42cc5e6 commit b84b85f

File tree

7 files changed

+342
-43
lines changed

7 files changed

+342
-43
lines changed

packages/service-library/src/servicelib/utils.py

Lines changed: 147 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,24 @@
44
In order to avoid cyclic dependencies, please
55
DO NOT IMPORT ANYTHING from .
66
"""
7+
78
import asyncio
89
import logging
910
import os
1011
import socket
1112
from collections.abc import Awaitable, Coroutine, Generator, Iterable
1213
from pathlib import Path
13-
from typing import Any, Final, cast
14+
from typing import Any, AsyncGenerator, AsyncIterable, Final, TypeVar, cast
1415

1516
import toolz
1617
from pydantic import NonNegativeInt
1718

1819
_logger = logging.getLogger(__name__)
1920

21+
_DEFAULT_GATHER_TASKS_GROUP_PREFIX: Final[str] = "gathered"
22+
_DEFAULT_LOGGER: Final[logging.Logger] = _logger
23+
_DEFAULT_LIMITED_CONCURRENCY: Final[int] = 1
24+
2025

2126
def is_production_environ() -> bool:
2227
"""
@@ -175,3 +180,144 @@ def unused_port() -> int:
175180
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
176181
s.bind(("127.0.0.1", 0))
177182
return cast(int, s.getsockname()[1])
183+
184+
185+
T = TypeVar("T")
186+
187+
188+
async def limited_as_completed(
    awaitables: Iterable[Awaitable[T]] | AsyncIterable[Awaitable[T]],
    *,
    limit: int = _DEFAULT_LIMITED_CONCURRENCY,
    tasks_group_prefix: str | None = None,
) -> AsyncGenerator[asyncio.Future[T], None]:
    """Runs awaitables using limited concurrent tasks and yields their
    result futures unordered, as they complete.

    Arguments:
        awaitables -- The awaitables to limit the concurrency of.
            May be a plain (sync) iterable or an async iterable.

    Keyword Arguments:
        limit -- The maximum number of awaitables to run concurrently.
            0 or negative values disables the limit. (default: {1})
        tasks_group_prefix -- The prefix to use for the name of the asyncio tasks group.
            If None, no name is used. (default: {None})

    Yields:
        Future[T]: the future of the awaitables as they appear (completion order,
        not input order).
    """
    # Detect sync vs async input: aiter() raises TypeError on a plain iterable.
    try:
        awaitable_iterator = aiter(awaitables)  # type: ignore[arg-type]
        is_async = True
    except TypeError:
        assert isinstance(awaitables, Iterable)  # nosec
        awaitable_iterator = iter(awaitables)  # type: ignore[assignment]
        is_async = False

    completed_all_awaitables = False
    pending_futures: set[asyncio.Future] = set()

    try:
        while pending_futures or not completed_all_awaitables:
            # Top up the pending set until the concurrency limit is reached
            # (no bound when limit < 1) or the input is exhausted.
            while (
                limit < 1 or len(pending_futures) < limit
            ) and not completed_all_awaitables:
                try:
                    aw = (
                        await anext(awaitable_iterator)
                        if is_async
                        else next(awaitable_iterator)  # type: ignore[call-overload]
                    )
                    future = asyncio.ensure_future(aw)
                    if tasks_group_prefix:
                        future.set_name(f"{tasks_group_prefix}-{future.get_name()}")
                    pending_futures.add(future)
                except (StopIteration, StopAsyncIteration):  # noqa: PERF203
                    completed_all_awaitables = True
            # Nothing scheduled and nothing left to schedule: done.
            # (Also guards asyncio.wait, which rejects an empty set.)
            if not pending_futures:
                return
            done, pending_futures = await asyncio.wait(
                pending_futures, return_when=asyncio.FIRST_COMPLETED
            )

            for future in done:
                yield future
    except asyncio.CancelledError:
        # On cancellation of this generator, cancel and drain all still-running
        # tasks before propagating, so no task is left dangling.
        for future in pending_futures:
            future.cancel()
        await asyncio.gather(*pending_futures, return_exceptions=True)
        raise
255+
256+
257+
async def _wrapped(
    awaitable: Awaitable[T], *, index: int, reraise: bool, logger: logging.Logger
) -> tuple[int, T | BaseException]:
    """Awaits *awaitable* and pairs its outcome with *index*.

    Cancellation is always logged (debug) and propagated. Any other exception
    is logged as a warning and either re-raised (reraise=True) or returned in
    place of the result so the caller can keep positional results.
    """
    try:
        result = await awaitable
    except asyncio.CancelledError:
        logger.debug(
            "Cancelled %i-th concurrent task %s",
            index + 1,
            f"{awaitable=}",
        )
        raise
    except BaseException as exc:  # pylint: disable=broad-exception-caught
        logger.warning(
            "Error in %i-th concurrent task %s: %s",
            index + 1,
            f"{awaitable=}",
            f"{exc=}",
        )
        if reraise:
            raise
        return index, exc
    return index, result
279+
280+
281+
async def limited_gather(
    *awaitables: Awaitable[T],
    reraise: bool = True,
    log: logging.Logger = _DEFAULT_LOGGER,
    limit: int = _DEFAULT_LIMITED_CONCURRENCY,
    tasks_group_prefix: str | None = None,
) -> list[T | BaseException | None]:
    """Runs all the awaitables using the limited concurrency and returns them
    in the same order as the inputs.

    Arguments:
        awaitables -- The awaitables to limit the concurrency of.

    Keyword Arguments:
        limit -- The maximum number of awaitables to run concurrently.
            Setting 0 or a negative value disables the limit. (default: {1})
        reraise -- If True, raises at the first exception; the remaining tasks
            continue as in standard asyncio gather.
            If False, the exceptions are returned in the results. (default: {True})
        log -- The logger to use for logging the exceptions. (default: {_logger})
        tasks_group_prefix -- The prefix to use for the name of the asyncio tasks group.
            If None, the 'gathered' prefix is used. (default: {None})

    Returns:
        the results of the awaitables keeping the order

    special thanks to: https://death.andgravity.com/limit-concurrency
    """
    wrapped_awaitables = [
        _wrapped(aw, reraise=reraise, index=position, logger=log)
        for position, aw in enumerate(awaitables)
    ]

    # Pre-size the result list; each wrapped awaitable reports its own slot.
    results: list[T | BaseException | None] = [None] * len(wrapped_awaitables)
    prefix = tasks_group_prefix or _DEFAULT_GATHER_TASKS_GROUP_PREFIX
    async for future in limited_as_completed(
        wrapped_awaitables,
        limit=limit,
        tasks_group_prefix=prefix,
    ):
        position, outcome = await future
        results[position] = outcome

    return results

packages/service-library/tests/test_archiving_utils.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from concurrent.futures import ProcessPoolExecutor
1515
from dataclasses import dataclass
1616
from pathlib import Path
17-
from typing import Callable, Iterable, Iterator, Optional
17+
from typing import Callable, Iterable, Iterator
1818

1919
import pytest
2020
from faker import Faker
@@ -23,7 +23,12 @@
2323
from servicelib import archiving_utils
2424
from servicelib.archiving_utils import ArchiveError, archive_dir, unarchive_dir
2525

26-
from .test_utils import print_tree
26+
27+
def _print_tree(path: Path, level=0):
    """Recursively prints *path* as an indented tree ('+' for dirs, '-' for files)."""
    indent = " " * level
    # Only the root shows its full path; children show just their name.
    label = path if level == 0 else path.name
    marker = "+" if path.is_dir() else "-"
    print(f"{indent}{marker} {label}")
    for child in path.glob("*"):
        _print_tree(child, level + 1)
2732

2833

2934
@pytest.fixture
@@ -96,7 +101,7 @@ def exclude_patterns_validation_dir(tmp_path: Path, faker: Faker) -> Path:
96101
(base_dir / "d1" / "sd1" / "f2.txt").write_text(faker.text())
97102

98103
print("exclude_patterns_validation_dir ---")
99-
print_tree(base_dir)
104+
_print_tree(base_dir)
100105
return base_dir
101106

102107

@@ -174,7 +179,7 @@ def _escape_undecodable_path(path: Path) -> Path:
174179
async def assert_same_directory_content(
175180
dir_to_compress: Path,
176181
output_dir: Path,
177-
inject_relative_path: Optional[Path] = None,
182+
inject_relative_path: Path | None = None,
178183
unsupported_replace: bool = False,
179184
) -> None:
180185
def _relative_path(input_path: Path) -> Path:

packages/service-library/tests/test_archiving_utils_extra.py

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@
1313
unarchive_dir,
1414
)
1515

16-
from .test_utils import print_tree
16+
17+
def _print_tree(path: Path, level=0):
    """Recursively prints *path* as an indented tree ('+' for dirs, '-' for files)."""
    indent = " " * level
    # Only the root shows its full path; children show just their name.
    label = path if level == 0 else path.name
    marker = "+" if path.is_dir() else "-"
    print(f"{indent}{marker} {label}")
    for child in path.glob("*"):
        _print_tree(child, level + 1)
1722

1823

1924
@pytest.fixture
@@ -32,7 +37,7 @@ def state_dir(tmp_path) -> Path:
3237
(base_dir / "d1" / "d1_1" / "d1_1_1" / "f6").touch()
3338

3439
print("state-dir ---")
35-
print_tree(base_dir)
40+
_print_tree(base_dir)
3641
# + /tmp/pytest-of-crespo/pytest-95/test_override_and_prune_from_a1/original
3742
# + empty
3843
# + d1
@@ -64,7 +69,7 @@ def new_state_dir(tmp_path) -> Path:
6469
# f6 deleted -> d1/d1_1/d2_2 remains empty and should be pruned
6570

6671
print("new-state-dir ---")
67-
print_tree(base_dir)
72+
_print_tree(base_dir)
6873
# + /tmp/pytest-of-crespo/pytest-95/test_override_and_prune_from_a1/updated
6974
# + d1
7075
# + d1_1
@@ -120,7 +125,7 @@ def test_override_and_prune_folder(state_dir: Path, new_state_dir: Path):
120125
assert old_paths != got_paths
121126

122127
print("after ----")
123-
print_tree(state_dir)
128+
_print_tree(state_dir)
124129

125130

126131
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)