|
1 | 1 | """Async utils.""" |
2 | 2 |
|
3 | 3 | import asyncio |
4 | | -from typing import Any, Coroutine, List |
| 4 | +from typing import Any, Coroutine, List, Optional |
| 5 | + |
| 6 | +from tqdm.auto import tqdm |
| 7 | + |
| 8 | +from ragas.executor import is_event_loop_running |
| 9 | +from ragas.utils import batched |
5 | 10 |
|
6 | 11 |
|
def run_async_tasks(
    tasks: List[Coroutine],
    batch_size: Optional[int] = None,
    show_progress: bool = True,
    progress_bar_desc: str = "Running async tasks",
) -> List[Any]:
    """
    Execute async tasks with optional batching and progress tracking.

    Without batching, every coroutine is scheduled at once. With batching,
    the coroutines of one batch run concurrently and the batches run one
    after another. Results are returned in the same order as ``tasks``,
    regardless of the order in which the coroutines complete.

    Args:
        tasks: List of coroutines to execute
        batch_size: Optional size for batching tasks. If None (or 0), runs
            all concurrently
        show_progress: Whether to display progress bars
        progress_bar_desc: Description shown on the overall progress bar

    Returns:
        The coroutine results, positionally aligned with ``tasks``.

    Raises:
        ImportError: If an event loop is already running and ``nest_asyncio``
            is not installed.
    """

    async def _tagged(position: int, coro: Coroutine):
        # Carry the input position along with the result so completion order
        # does not decide where the result ends up.
        return position, await coro

    async def _drain(indexed_coros, results: List[Any], *pbars) -> None:
        # Run one group of (position, coroutine) pairs concurrently, storing
        # each result at its input position and ticking every progress bar.
        pending = [_tagged(position, coro) for position, coro in indexed_coros]
        for future in asyncio.as_completed(pending):
            position, value = await future
            results[position] = value
            for pbar in pbars:
                pbar.update(1)

    async def _run() -> List[Any]:
        total_tasks = len(tasks)
        results: List[Any] = [None] * total_tasks
        indexed = list(enumerate(tasks))

        # If no batching, run all tasks concurrently with single progress bar
        if not batch_size:
            with tqdm(
                total=total_tasks,
                desc=progress_bar_desc,
                disable=not show_progress,
            ) as pbar:
                await _drain(indexed, results, pbar)
            return results

        # With batching, show nested progress bars: the outer bar counts all
        # tasks, the inner bar is reset for every batch.
        n_batches = (total_tasks + batch_size - 1) // batch_size
        # Nested ``with`` blocks instead of the parenthesized multi-item form,
        # which is only official grammar from Python 3.10 onwards.
        with tqdm(
            total=total_tasks,
            desc=progress_bar_desc,
            disable=not show_progress,
            position=0,
            leave=True,
        ) as overall_pbar:
            with tqdm(
                total=batch_size,
                desc=f"Batch 1/{n_batches}",
                disable=not show_progress,
                position=1,
                leave=False,
            ) as batch_pbar:
                for i, batch in enumerate(batched(indexed, batch_size), 1):
                    batch_pbar.reset(total=len(batch))
                    batch_pbar.set_description(f"Batch {i}/{n_batches}")
                    await _drain(batch, results, overall_pbar, batch_pbar)

        return results

    if is_event_loop_running():
        # an event loop is already running (e.g. in a notebook), so patch it
        # with nest_asyncio to allow the nested asyncio.run() call below
        try:
            import nest_asyncio
        except ImportError as exc:
            raise ImportError(
                "It seems like your running this in a jupyter-like environment. "
                "Please install nest_asyncio with `pip install nest_asyncio` to make it work."
            ) from exc
        else:
            nest_asyncio.apply()

    return asyncio.run(_run())
0 commit comments