Skip to content

Commit 7951255

Browse files
committed
feat(core): add base operation classes and token counting implementations
1 parent 1c8b220 commit 7951255

File tree

11 files changed

+1093
-4
lines changed

11 files changed

+1093
-4
lines changed

reme_ai/core/llm/lite_llm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ def _build_stream_kwargs(
111111
**kwargs: Additional parameters to pass to the API (e.g., temperature, max_tokens)
112112
113113
Returns:
114-
Dictionary of parameters ready for LiteLLM API call
114+
Dictionary of parameters for LiteLLM API call
115115
"""
116116
# Construct the API parameters by merging multiple sources
117117
llm_kwargs = {

reme_ai/core/op/__init__.py

Whitespace-only changes.

reme_ai/core/op/base_async_op.py

Lines changed: 300 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,300 @@
1+
"""Base class for asynchronous operations.
2+
3+
This module provides the BaseAsyncOp class, which extends BaseOp with
4+
asynchronous execution capabilities. It supports async cache operations,
5+
async task submission and joining, and full async execution lifecycle.
6+
"""
7+
8+
import asyncio
9+
from abc import ABCMeta, abstractmethod
10+
from typing import Any, Callable
11+
12+
from loguru import logger
13+
14+
from .base_op import BaseOp
15+
from ..context import FlowContext, C
16+
17+
18+
class BaseAsyncOp(BaseOp, metaclass=ABCMeta):
19+
"""Base class for asynchronous operations.
20+
21+
This class extends BaseOp to provide asynchronous execution capabilities.
22+
All operations created from this class run in async mode by default.
23+
24+
Example:
25+
```python
26+
class MyAsyncOp(BaseAsyncOp):
27+
async def async_execute(self):
28+
return await some_async_function()
29+
30+
op = MyAsyncOp()
31+
result = await op.async_call()
32+
```
33+
"""
34+
35+
def __init__(self, **kwargs):
36+
"""Initialize the async operation.
37+
38+
Automatically sets async_mode to True if not explicitly set.
39+
40+
Args:
41+
**kwargs: Arguments passed to BaseOp.__init__
42+
"""
43+
kwargs.setdefault("async_mode", True)
44+
super().__init__(**kwargs)
45+
46+
def execute(self):
47+
"""Placeholder for synchronous execute method.
48+
49+
This method is not used in async operations. Subclasses should implement
50+
`async_execute()` instead. This method exists only to satisfy the abstract
51+
method requirement from BaseOp.
52+
"""
53+
54+
async def async_save_load_cache(self, key: str, fn: Callable, **kwargs):
55+
"""Save or load from cache asynchronously.
56+
57+
If caching is enabled, checks cache first. If not found, executes the
58+
function (async or sync) and saves the result. Supports both coroutine
59+
functions and regular functions (executed in thread pool).
60+
61+
Args:
62+
key: Cache key for storing/retrieving the result
63+
fn: Function to execute if cache miss (can be async or sync)
64+
**kwargs: Additional arguments for cache load operation
65+
66+
Returns:
67+
Cached result if available, otherwise result from function execution
68+
"""
69+
if self.enable_cache:
70+
result = self.cache.load(key, **kwargs)
71+
if result is None:
72+
if asyncio.iscoroutinefunction(fn):
73+
result = await fn()
74+
else:
75+
loop = asyncio.get_event_loop()
76+
result = await loop.run_in_executor(C.thread_pool, fn)
77+
78+
self.cache.save(key, result, expire_hours=self.cache_expire_hours)
79+
else:
80+
logger.info(f"load {key} from cache")
81+
else:
82+
if asyncio.iscoroutinefunction(fn):
83+
result = await fn()
84+
else:
85+
loop = asyncio.get_event_loop()
86+
result = await loop.run_in_executor(C.thread_pool, fn)
87+
88+
return result
89+
90+
async def async_before_execute(self):
91+
"""Hook method called before async_execute(). Override in subclasses.
92+
93+
This method is called automatically by `async_call()` before executing
94+
the main `async_execute()` method. Use this to perform any setup,
95+
validation, or preprocessing needed before execution.
96+
97+
Example:
98+
```python
99+
async def async_before_execute(self):
100+
# Validate inputs
101+
if not self.context.get("input"):
102+
raise ValueError("Input is required")
103+
```
104+
"""
105+
106+
async def async_after_execute(self):
107+
"""Hook method called after async_execute(). Override in subclasses.
108+
109+
This method is called automatically by `async_call()` after successfully
110+
executing the main `async_execute()` method. Use this to perform any
111+
cleanup, post-processing, or result transformation.
112+
113+
Example:
114+
```python
115+
async def async_after_execute(self):
116+
# Post-process results
117+
if self.context.response:
118+
self.context.response.answer = self.context.response.answer.upper()
119+
```
120+
"""
121+
122+
@abstractmethod
123+
async def async_execute(self):
124+
"""Main async execution method. Must be implemented in subclasses.
125+
126+
Returns:
127+
Execution result
128+
"""
129+
130+
async def async_default_execute(self, e: Exception = None, **kwargs):
131+
"""Default async execution method when main execution fails. Override in subclasses.
132+
133+
This method is called when `async_execute()` fails and `raise_exception`
134+
is False. It provides a fallback mechanism to return a default result
135+
instead of raising an exception.
136+
137+
Args:
138+
e: The exception that was raised during execution (if any)
139+
**kwargs: Additional keyword arguments
140+
141+
Returns:
142+
Default execution result
143+
144+
Example:
145+
```python
146+
async def async_default_execute(self, e: Exception = None, **kwargs):
147+
logger.warning(f"Execution failed: {e}, returning default result")
148+
return {"status": "error", "message": str(e)}
149+
```
150+
"""
151+
152+
async def async_call(self, context: FlowContext = None, **kwargs) -> Any:
153+
"""Execute the operation asynchronously.
154+
155+
This method handles the full async execution lifecycle including retries,
156+
error handling, and context management. It automatically calls
157+
`async_before_execute()`, `async_execute()`, and `async_after_execute()`
158+
in sequence.
159+
160+
Args:
161+
context: Flow context for this execution. If None, a new context
162+
will be created.
163+
**kwargs: Additional context updates to merge into the context
164+
165+
Returns:
166+
Execution result from `async_execute()`, context response if result
167+
is None, or None if both are None
168+
169+
Raises:
170+
Exception: If execution fails and `raise_exception` is True and
171+
`max_retries` is exhausted
172+
"""
173+
self.context = self.build_context(context, **kwargs)
174+
with self.timer:
175+
result = None
176+
if self.max_retries == 1 and self.raise_exception:
177+
await self.async_before_execute()
178+
result = await self.async_execute()
179+
await self.async_after_execute()
180+
181+
else:
182+
for i in range(self.max_retries):
183+
try:
184+
await self.async_before_execute()
185+
result = await self.async_execute()
186+
await self.async_after_execute()
187+
break
188+
189+
except Exception as e:
190+
logger.exception(f"op={self.name} async execute failed, error={e.args}")
191+
192+
if self.raise_exception and i == self.max_retries - 1:
193+
raise e
194+
195+
result = await self.async_default_execute(e)
196+
197+
if result is not None:
198+
return result
199+
200+
elif self.context is not None and self.context.response is not None:
201+
return self.context.response
202+
203+
else:
204+
return None
205+
206+
def submit_async_task(self, fn: Callable, *args, **kwargs):
207+
"""Submit an async task for execution.
208+
209+
Creates an asyncio task and adds it to the task list for later joining.
210+
Tasks can be collected using `join_async_task()`.
211+
212+
Args:
213+
fn: Coroutine function to execute
214+
*args: Positional arguments for the coroutine
215+
**kwargs: Keyword arguments for the coroutine
216+
217+
Note:
218+
Only coroutine functions are supported. Non-coroutine functions
219+
will trigger a warning and be ignored.
220+
221+
Example:
222+
```python
223+
async def my_task(x):
224+
return x * 2
225+
226+
self.submit_async_task(my_task, 5)
227+
results = await self.join_async_task()
228+
```
229+
"""
230+
loop = asyncio.get_running_loop()
231+
if asyncio.iscoroutinefunction(fn):
232+
task = loop.create_task(fn(*args, **kwargs))
233+
self.task_list.append(task)
234+
else:
235+
logger.warning("submit_async_task failed, fn is not a coroutine function!")
236+
237+
async def join_async_task(self, timeout: float = None, return_exceptions: bool = True):
238+
"""Wait for all submitted async tasks to complete and collect results.
239+
240+
Collects results from all tasks, handling exceptions and timeouts.
241+
On timeout or exception, all remaining tasks are cancelled.
242+
243+
Args:
244+
timeout: Maximum time to wait in seconds (None for no timeout)
245+
return_exceptions: Whether to return exceptions as results
246+
247+
Returns:
248+
List of task results (exceptions included if return_exceptions=True)
249+
250+
Raises:
251+
asyncio.TimeoutError: If timeout is exceeded
252+
Exception: If any task raises an exception and return_exceptions=False
253+
"""
254+
result = []
255+
256+
if not self.task_list:
257+
return result
258+
259+
try:
260+
if timeout is not None:
261+
gather_task = asyncio.gather(*self.task_list, return_exceptions=return_exceptions)
262+
task_results = await asyncio.wait_for(gather_task, timeout=timeout)
263+
else:
264+
task_results = await asyncio.gather(*self.task_list, return_exceptions=return_exceptions)
265+
266+
for t_result in task_results:
267+
if return_exceptions and isinstance(t_result, Exception):
268+
logger.opt(exception=t_result).error("Task failed with exception")
269+
continue
270+
271+
if t_result:
272+
if isinstance(t_result, list):
273+
result.extend(t_result)
274+
else:
275+
result.append(t_result)
276+
277+
except asyncio.TimeoutError:
278+
logger.exception(f"join_async_task timeout after {timeout}s, cancelling {len(self.task_list)} tasks...")
279+
for task in self.task_list:
280+
if not task.done():
281+
task.cancel()
282+
283+
await asyncio.gather(*self.task_list, return_exceptions=True)
284+
self.task_list.clear()
285+
raise
286+
287+
except Exception as e:
288+
logger.exception(f"join_async_task failed with {type(e).__name__}, cancelling remaining tasks...")
289+
for task in self.task_list:
290+
if not task.done():
291+
task.cancel()
292+
293+
await asyncio.gather(*self.task_list, return_exceptions=True)
294+
self.task_list.clear()
295+
raise
296+
297+
finally:
298+
self.task_list.clear()
299+
300+
return result

0 commit comments

Comments
 (0)