11"""A method to generate outputs based on python functions and a Generative Slot function."""
22
3+ import asyncio
34import functools
45import inspect
5- from collections .abc import Callable
6+ from collections .abc import Callable , Coroutine
67from copy import deepcopy
78from typing import Any , Generic , ParamSpec , TypedDict , TypeVar , get_type_hints
89
@@ -168,14 +169,13 @@ def __call__(
168169 **kwargs: Additional Kwargs to be passed to the func.
169170
170171 Returns:
171- ModelOutputThunk: Output with generated Thunk.
172+ R: an object with the original return type of the function
172173 """
173174 if m is None :
174175 m = get_session ()
175176 slot_copy = deepcopy (self )
176177 arguments = bind_function_arguments (self ._function ._func , * args , ** kwargs )
177178 if arguments :
178- # slot_copy._arguments = []
179179 for key , val in arguments .items ():
180180 annotation = get_annotation (slot_copy ._function ._func , key , val )
181181 slot_copy ._arguments .append (Argument (annotation , key , val ))
@@ -207,6 +207,52 @@ def format_for_llm(self) -> TemplateRepresentation:
207207 )
208208
209209
class AsyncGenerativeSlot(GenerativeSlot, Generic[P, R]):
    """Generative slot for coroutine functions: calling it hands back an awaitable."""

    def __call__(
        self,
        m: MelleaSession | None = None,
        model_options: dict | None = None,
        *args: P.args,
        **kwargs: P.kwargs,
    ) -> Coroutine[Any, Any, R]:
        """Bind the call arguments and return a coroutine that runs the generation.

        The slot copy, argument binding and response format are all prepared
        when this method is called; only the model request and the parsing of
        its output are deferred until the returned coroutine is awaited.

        Args:
            m: Session to generate with. When None, ``get_session()`` is used.
            model_options: Options forwarded to the session's ``aact`` call.
            *args: Positional arguments for the wrapped function. Because ``m``
                and ``model_options`` come first in the signature, the first two
                positional values of a call bind to those parameters.
            **kwargs: Keyword arguments for the wrapped function.

        Returns:
            Coroutine[Any, Any, R]: a coroutine whose result is the ``result``
            field of the validated response, i.e. a value of the wrapped
            function's return type.
        """
        session = get_session() if m is None else m

        # Work on a copy so the bound arguments never accumulate on the shared slot.
        bound_slot = deepcopy(self)
        bound = bind_function_arguments(self._function._func, *args, **kwargs)
        if bound:
            bound_slot._arguments.extend(
                Argument(
                    get_annotation(bound_slot._function._func, name, value),
                    name,
                    value,
                )
                for name, value in bound.items()
            )

        schema = create_response_format(self._function._func)

        # The decorated function is async, so callers expect to ``await`` the
        # call; hand them a coroutine object instead of a finished value.
        async def _generate():
            # Await the session's async entry point rather than a blocking call,
            # so a running event loop is not stalled here.
            raw = await session.aact(
                bound_slot, format=schema, model_options=model_options
            )
            parsed: FunctionResponse[R] = schema.model_validate_json(
                raw.value  # type: ignore
            )
            return parsed.result

        return _generate()
254+
255+
210256def generative (func : Callable [P , R ]) -> GenerativeSlot [P , R ]:
211257 """Convert a function into an AI-powered function.
212258
@@ -216,6 +262,8 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
216262 that function's behavior. The output is guaranteed to match the return type
217263 annotation using structured outputs and automatic validation.
218264
265+ Note: Works with async functions as well.
266+
219267 Tip: Write the function and docstring in the most Pythonic way possible, not
220268 like a prompt. This ensures the function is well-documented, easily understood,
221269 and familiar to any Python developer. The more natural and conventional your
@@ -248,7 +296,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
248296 ... estimated_hours: float
249297 >>>
250298 >>> @generative
251- ... def create_project_tasks(project_desc: str, count: int) -> List[Task]:
299+ ... async def create_project_tasks(project_desc: str, count: int) -> List[Task]:
252300 ... '''Generate a list of realistic tasks for a project.
253301 ...
254302 ... Args:
@@ -260,7 +308,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
260308 ... '''
261309 ... ...
262310 >>>
263- >>> tasks = create_project_tasks(session, "Build a web app", 5)
311+ >>> tasks = await create_project_tasks(session, "Build a web app", 5)
264312
265313 >>> @generative
266314 ... def analyze_code_quality(code: str) -> Dict[str, Any]:
@@ -304,8 +352,46 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]:
304352 >>>
305353 >>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?")
306354 """
307- return GenerativeSlot (func )
355+ if inspect .iscoroutinefunction (func ):
356+ return AsyncGenerativeSlot (func )
357+ else :
358+ return GenerativeSlot (func )
308359
309360
310361# Export the decorator as the interface
311362__all__ = ["generative" ]
363+
364+
if __name__ == "__main__":
    # Manual smoke test: exercises one sync and one async generative function.
    # Requires a working mellea session/backend, so it only runs as a script.
    from mellea import start_session

    with start_session():
        # NOTE: the earlier draft also created a bare `asyncly()` coroutine that
        # was never awaited; it did nothing except trigger a
        # "coroutine ... was never awaited" RuntimeWarning, so it is gone.

        @generative
        async def test_async(num: int) -> bool: ...

        @generative
        def test_sync(truthy: bool) -> int: ...

        # m=None makes each call fall back to get_session() — presumably the
        # session opened by start_session() above.
        print("running sync")
        print(test_sync(m=None, model_options=None, truthy=False))

        async def runmany():
            """Await the async slot one call at a time, then three concurrently."""
            print(await test_async(m=None, model_options=None, num=6))
            print(await test_async(m=None, model_options=None, num=4))
            print(await test_async(m=None, model_options=None, num=5))

            # Calling the slot only builds coroutines; gather() drives them together.
            coros = [
                test_async(m=None, model_options=None, num=1),
                test_async(m=None, model_options=None, num=2),
                test_async(m=None, model_options=None, num=3),
            ]
            results = await asyncio.gather(*coros)
            print(results)

        print("running async")
        asyncio.run(runmany())
0 commit comments