"""Base class for asynchronous operations.

This module provides the BaseAsyncOp class, which extends BaseOp with
asynchronous execution capabilities. It supports async cache operations,
async task submission and joining, and the full async execution lifecycle.
"""

import asyncio
from abc import ABCMeta, abstractmethod
from typing import Any, Callable, Optional

from loguru import logger

from .base_op import BaseOp
from ..context import FlowContext, C


class BaseAsyncOp(BaseOp, metaclass=ABCMeta):
    """Base class for asynchronous operations.

    This class extends BaseOp to provide asynchronous execution capabilities.
    All operations created from this class run in async mode by default.

    Example:
        ```python
        class MyAsyncOp(BaseAsyncOp):
            async def async_execute(self):
                return await some_async_function()

        op = MyAsyncOp()
        result = await op.async_call()
        ```
    """

    def __init__(self, **kwargs):
        """Initialize the async operation.

        Automatically sets async_mode to True if not explicitly set.

        Args:
            **kwargs: Arguments passed to BaseOp.__init__
        """
        kwargs.setdefault("async_mode", True)
        super().__init__(**kwargs)

    def execute(self):
        """Placeholder for the synchronous execute method.

        This method is not used in async operations; subclasses should implement
        `async_execute()` instead. It exists only to satisfy the abstract method
        requirement from BaseOp.
        """

    async def async_save_load_cache(self, key: str, fn: Callable, **kwargs):
        """Save or load from cache asynchronously.

        If caching is enabled, checks the cache first. On a cache miss, executes
        the function (async or sync) and saves the result. Supports both coroutine
        functions and regular functions (the latter run in the thread pool).

        Args:
            key: Cache key for storing/retrieving the result
            fn: Function to execute on cache miss (can be async or sync)
            **kwargs: Additional arguments for the cache load operation

        Returns:
            Cached result if available, otherwise the result of executing the function
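
        Example:
            A minimal usage sketch; `fetch_data` is an illustrative async helper
            assumed to be defined on the subclass, not part of this class:

            ```python
            async def async_execute(self):
                return await self.async_save_load_cache("user_data", self.fetch_data)
            ```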
        """
        if self.enable_cache:
            result = self.cache.load(key, **kwargs)
            if result is None:
                if asyncio.iscoroutinefunction(fn):
                    result = await fn()
                else:
                    # Run sync functions in the shared thread pool so the event loop is not blocked.
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(C.thread_pool, fn)

                self.cache.save(key, result, expire_hours=self.cache_expire_hours)
            else:
                logger.info(f"load {key} from cache")
        else:
            if asyncio.iscoroutinefunction(fn):
                result = await fn()
            else:
                loop = asyncio.get_running_loop()
                result = await loop.run_in_executor(C.thread_pool, fn)

        return result

    async def async_before_execute(self):
        """Hook method called before async_execute(). Override in subclasses.

        This method is called automatically by `async_call()` before executing
        the main `async_execute()` method. Use this to perform any setup,
        validation, or preprocessing needed before execution.

        Example:
            ```python
            async def async_before_execute(self):
                # Validate inputs
                if not self.context.get("input"):
                    raise ValueError("Input is required")
            ```
        """

    async def async_after_execute(self):
        """Hook method called after async_execute(). Override in subclasses.

        This method is called automatically by `async_call()` after successfully
        executing the main `async_execute()` method. Use this to perform any
        cleanup, post-processing, or result transformation.

        Example:
            ```python
            async def async_after_execute(self):
                # Post-process results
                if self.context.response:
                    self.context.response.answer = self.context.response.answer.upper()
            ```
        """

    @abstractmethod
    async def async_execute(self):
        """Main async execution method. Must be implemented in subclasses.

        Returns:
            Execution result
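
        Example:
            A minimal sketch; `self.context.query` and `summarize` are
            illustrative names, not part of BaseAsyncOp:

            ```python
            async def async_execute(self):
                return await summarize(self.context.query)
            ```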
        """

    async def async_default_execute(self, e: Optional[Exception] = None, **kwargs):
        """Default async execution used when the main execution fails. Override in subclasses.

        This method is called when `async_execute()` fails and `raise_exception`
        is False. It provides a fallback mechanism to return a default result
        instead of raising an exception.

        Args:
            e: The exception that was raised during execution (if any)
            **kwargs: Additional keyword arguments

        Returns:
            Default execution result

        Example:
            ```python
            async def async_default_execute(self, e: Exception = None, **kwargs):
                logger.warning(f"Execution failed: {e}, returning default result")
                return {"status": "error", "message": str(e)}
            ```
        """

    async def async_call(self, context: Optional[FlowContext] = None, **kwargs) -> Any:
        """Execute the operation asynchronously.

        This method handles the full async execution lifecycle including retries,
        error handling, and context management. It automatically calls
        `async_before_execute()`, `async_execute()`, and `async_after_execute()`
        in sequence.

        Args:
            context: Flow context for this execution. If None, a new context
                will be created.
            **kwargs: Additional context updates to merge into the context

        Returns:
            Execution result from `async_execute()`, the context response if the
            result is None, or None if both are None

        Raises:
            Exception: If execution fails, `raise_exception` is True, and
                `max_retries` is exhausted
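
        Example:
            A minimal usage sketch; the `query` keyword is illustrative and is
            simply merged into the flow context:

            ```python
            op = MyAsyncOp()
            result = await op.async_call(query="What is RAG?")
            ```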
        """
        self.context = self.build_context(context, **kwargs)
        with self.timer:
            result = None
            # Fast path: a single attempt that must raise on failure needs no retry handling.
            if self.max_retries == 1 and self.raise_exception:
                await self.async_before_execute()
                result = await self.async_execute()
                await self.async_after_execute()

            else:
                for i in range(self.max_retries):
                    try:
                        await self.async_before_execute()
                        result = await self.async_execute()
                        await self.async_after_execute()
                        break

                    except Exception as e:
                        logger.exception(f"op={self.name} async execute failed, error={e.args}")

                        if self.raise_exception and i == self.max_retries - 1:
                            raise e

                        result = await self.async_default_execute(e)

            # Prefer the explicit result; fall back to the context response if present.
            if result is not None:
                return result

            elif self.context is not None and self.context.response is not None:
                return self.context.response

            else:
                return None

    def submit_async_task(self, fn: Callable, *args, **kwargs):
        """Submit an async task for execution.

        Creates an asyncio task and adds it to the task list for later joining.
        Results can be collected with `join_async_task()`.

        Args:
            fn: Coroutine function to execute
            *args: Positional arguments for the coroutine
            **kwargs: Keyword arguments for the coroutine

        Note:
            Only coroutine functions are supported. Non-coroutine functions
            trigger a warning and are ignored.

        Example:
            ```python
            async def my_task(x):
                return x * 2

            self.submit_async_task(my_task, 5)
            results = await self.join_async_task()
            ```
        """
        loop = asyncio.get_running_loop()
        if asyncio.iscoroutinefunction(fn):
            task = loop.create_task(fn(*args, **kwargs))
            self.task_list.append(task)
        else:
            logger.warning("submit_async_task failed, fn is not a coroutine function!")

    async def join_async_task(self, timeout: Optional[float] = None, return_exceptions: bool = True):
        """Wait for all submitted async tasks to complete and collect their results.

        Collects results from all tasks, handling exceptions and timeouts.
        On timeout or failure, all remaining tasks are cancelled.

        Args:
            timeout: Maximum time to wait in seconds (None for no timeout)
            return_exceptions: If True, task exceptions are logged and skipped
                instead of being propagated

        Returns:
            Flattened list of truthy task results (list results are extended into
            the output; exceptions are logged and skipped when return_exceptions=True)

        Raises:
            asyncio.TimeoutError: If the timeout is exceeded
            Exception: If any task raises an exception and return_exceptions=False
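
        Example:
            A minimal sketch of bounding the wait with a timeout; the coroutine
            name, its argument, and the 10-second value are illustrative:

            ```python
            self.submit_async_task(fetch_page, url)
            results = await self.join_async_task(timeout=10.0)
            ```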
        """
        result = []

        if not self.task_list:
            return result

        try:
            if timeout is not None:
                gather_task = asyncio.gather(*self.task_list, return_exceptions=return_exceptions)
                task_results = await asyncio.wait_for(gather_task, timeout=timeout)
            else:
                task_results = await asyncio.gather(*self.task_list, return_exceptions=return_exceptions)

            for t_result in task_results:
                if return_exceptions and isinstance(t_result, Exception):
                    logger.opt(exception=t_result).error("Task failed with exception")
                    continue

                if t_result:
                    if isinstance(t_result, list):
                        result.extend(t_result)
                    else:
                        result.append(t_result)

        except asyncio.TimeoutError:
            logger.exception(f"join_async_task timeout after {timeout}s, cancelling {len(self.task_list)} tasks...")
            # Cancel anything still running and wait for the cancellations to settle.
            for task in self.task_list:
                if not task.done():
                    task.cancel()

            await asyncio.gather(*self.task_list, return_exceptions=True)
            self.task_list.clear()
            raise

        except Exception as e:
            logger.exception(f"join_async_task failed with {type(e).__name__}, cancelling remaining tasks...")
            for task in self.task_list:
                if not task.done():
                    task.cancel()

            await asyncio.gather(*self.task_list, return_exceptions=True)
            self.task_list.clear()
            raise

        finally:
            # Always reset the task list so the op can submit a fresh batch later.
            self.task_list.clear()

        return result