Skip to content

Commit 9a88669

Browse files
committed
feat: add async generative slots
1 parent b956f8d commit 9a88669

File tree

2 files changed

+110
-6
lines changed

2 files changed

+110
-6
lines changed

mellea/stdlib/genslot.py

Lines changed: 92 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
"""A method to generate outputs based on python functions and a Generative Slot function."""
22

3+
import asyncio
34
import functools
45
import inspect
5-
from collections.abc import Callable
6+
from collections.abc import Callable, Coroutine
67
from copy import deepcopy
78
from typing import Any, Generic, ParamSpec, TypedDict, TypeVar, get_type_hints
89

@@ -168,14 +169,13 @@ def __call__(
168169
**kwargs: Additional Kwargs to be passed to the func.
169170
170171
Returns:
171-
ModelOutputThunk: Output with generated Thunk.
172+
R: an object with the original return type of the function
172173
"""
173174
if m is None:
174175
m = get_session()
175176
slot_copy = deepcopy(self)
176177
arguments = bind_function_arguments(self._function._func, *args, **kwargs)
177178
if arguments:
178-
# slot_copy._arguments = []
179179
for key, val in arguments.items():
180180
annotation = get_annotation(slot_copy._function._func, key, val)
181181
slot_copy._arguments.append(Argument(annotation, key, val))
@@ -207,6 +207,52 @@ def format_for_llm(self) -> TemplateRepresentation:
207207
)
208208

209209

210+
class AsyncGenerativeSlot(GenerativeSlot, Generic[P, R]):
    """A generative slot component that generates asynchronously and returns a coroutine."""

    def __call__(
        self,
        m: MelleaSession | None = None,
        model_options: dict | None = None,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Coroutine[Any, Any, R]:
        """Call the async generative slot.

        The arguments are bound and the response format is built eagerly, in this
        synchronous call. Only the model request and the parsing of its response
        are deferred to the returned coroutine.

        Args:
            m: A mellea session. When None, the session returned by
                ``get_session()`` is used.
            model_options: Options forwarded unchanged to ``m.aact``.
            *args: Positional arguments for the wrapped function. They come
                after ``m`` and ``model_options`` in the signature, so a second
                positional value is always taken as ``model_options``; pass the
                wrapped function's arguments by keyword to avoid surprises.
            **kwargs: Keyword arguments for the wrapped function.

        Returns:
            Coroutine[Any, Any, R]: a coroutine that, once awaited, returns an
            object with the original return type of the function.
        """
        if m is None:
            m = get_session()

        # Work on a copy so the arguments of this call never accumulate on the
        # shared, decorated slot object.
        slot_copy = deepcopy(self)
        arguments = bind_function_arguments(self._function._func, *args, **kwargs)
        if arguments:
            for key, val in arguments.items():
                annotation = get_annotation(slot_copy._function._func, key, val)
                slot_copy._arguments.append(Argument(annotation, key, val))

        response_model = create_response_format(self._function._func)

        # AsyncGenerativeSlots wrap `async def` functions, so callers expect to
        # `await` the result: hand back a coroutine object instead of a value.
        async def _generate() -> R:
            # Use the async act func so that control flow doesn't get stuck
            # here in async event loops.
            response = await m.aact(
                slot_copy, format=response_model, model_options=model_options
            )

            function_response: FunctionResponse[R] = response_model.model_validate_json(
                response.value  # type: ignore
            )
            return function_response.result

        return _generate()
254+
255+
210256
def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
211257
"""Convert a function into an AI-powered function.
212258
@@ -216,6 +262,8 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
216262
that function's behavior. The output is guaranteed to match the return type
217263
annotation using structured outputs and automatic validation.
218264
265+
Note: Also works with `async def` functions; calling the decorated function then returns a coroutine that must be awaited.
266+
219267
Tip: Write the function and docstring in the most Pythonic way possible, not
220268
like a prompt. This ensures the function is well-documented, easily understood,
221269
and familiar to any Python developer. The more natural and conventional your
@@ -248,7 +296,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
248296
... estimated_hours: float
249297
>>>
250298
>>> @generative
251-
... def create_project_tasks(project_desc: str, count: int) -> List[Task]:
299+
... async def create_project_tasks(project_desc: str, count: int) -> List[Task]:
252300
... '''Generate a list of realistic tasks for a project.
253301
...
254302
... Args:
@@ -260,7 +308,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
260308
... '''
261309
... ...
262310
>>>
263-
>>> tasks = create_project_tasks(session, "Build a web app", 5)
311+
>>> tasks = await create_project_tasks(session, project_desc="Build a web app", count=5)
264312
265313
>>> @generative
266314
... def analyze_code_quality(code: str) -> Dict[str, Any]:
@@ -304,8 +352,46 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
304352
>>>
305353
>>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?")
306354
"""
307-
return GenerativeSlot(func)
355+
if inspect.iscoroutinefunction(func):
356+
return AsyncGenerativeSlot(func)
357+
else:
358+
return GenerativeSlot(func)
308359

309360

310361
# Export the decorator as the interface
311362
__all__ = ["generative"]
363+
364+
365+
if __name__ == "__main__":
    # Manual smoke test for the sync and async generative slots. It only runs
    # when this module is executed directly.
    from mellea import start_session

    # Every call below passes m=None, so the slots look the session up through
    # get_session(); presumably start_session() is what makes that session
    # available, which is why all calls stay inside this `with` block.
    with start_session():

        # NOTE: no docstrings on the stubs on purpose; only `#` comments here.
        @generative
        async def test_async(num: int) -> bool: ...

        @generative
        def test_sync(truthy: bool) -> int: ...

        print("running sync")
        print(test_sync(m=None, model_options=None, truthy=False))

        async def runmany():
            # Sequential awaits: each request completes before the next starts.
            for num in (6, 4, 5):
                print(await test_async(m=None, model_options=None, num=num))

            # Create all coroutines first, then await them together; gather
            # returns the results in the order the coroutines were passed.
            coros = [
                test_async(m=None, model_options=None, num=num) for num in (1, 2, 3)
            ]
            results = await asyncio.gather(*coros)
            print(results)

        print("running async")
        asyncio.run(runmany())

test/stdlib_basics/test_genslot.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
import asyncio
12
import pytest
23
from typing import Literal
34
from mellea import generative, start_session
5+
from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot
46

57

68
@generative
@@ -10,6 +12,8 @@ def classify_sentiment(text: str) -> Literal["positive", "negative"]: ...
1012
@generative
1113
def write_me_an_email() -> str: ...
1214

15+
@generative
16+
async def async_write_short_sentence(topic: str) -> str: ...
1317

1418
@pytest.fixture(scope="function")
1519
def session():
@@ -29,6 +33,7 @@ def test_gen_slot_output(classify_sentiment_output):
2933

3034

3135
def test_func(session):
    """A sync ``@generative`` function is a plain GenerativeSlot and returns a ``str``."""
    # AsyncGenerativeSlot subclasses GenerativeSlot, so the async variant has
    # to be ruled out explicitly.
    is_slot = isinstance(write_me_an_email, GenerativeSlot)
    is_async_slot = isinstance(write_me_an_email, AsyncGenerativeSlot)
    assert is_slot and not is_async_slot

    email = write_me_an_email(session)
    assert isinstance(email, str)
3439

@@ -43,5 +48,18 @@ def test_gen_slot_logs(classify_sentiment_output, session):
4348
assert isinstance(last_prompt, dict)
4449
assert set(last_prompt.keys()) == {"role", "content", "images"}
4550

51+
# NOTE(review): this coroutine test only runs if the project's pytest setup
# collects `async def` tests (e.g. an asyncio plugin in auto mode) — confirm.
async def test_async_gen_slot(session):
    """An ``async def`` generative function yields awaitable slots that return ``str``."""
    assert isinstance(async_write_short_sentence, AsyncGenerativeSlot)

    # Calling the slot only builds a coroutine; nothing is generated until it
    # is awaited.
    r1 = async_write_short_sentence(session, topic="cats")
    r2 = async_write_short_sentence(session, topic="dogs")
    assert asyncio.iscoroutine(r1)
    assert asyncio.iscoroutine(r2)

    # Exercise both usage patterns: a direct await and a gathered batch.
    r3 = await async_write_short_sentence(session, topic="fish")
    results = await asyncio.gather(r1, r2)

    assert isinstance(r3, str)
    assert len(results) == 2
    # Check the gathered values as well, not only how many came back.
    assert all(isinstance(result, str) for result in results)
62+
63+
4664
if __name__ == "__main__":
4765
pytest.main([__file__])

0 commit comments

Comments
 (0)