diff --git a/docs/tutorial.md b/docs/tutorial.md
index 3640300f..933b7b76 100644
--- a/docs/tutorial.md
+++ b/docs/tutorial.md
@@ -21,6 +21,7 @@
 - [Chapter 10: Prompt Engineering for Mellea](#chapter-10-prompt-engineering-for-m)
 - [Custom Templates](#custom-templates)
 - [Chapter 11: Tool Calling](#chapter-11-tool-calling)
+- [Chapter 12: Asynchronicity](#chapter-12-asynchronicity)
 - [Appendix: Contributing to Melles](#appendix-contributing-to-mellea)

 ## Chapter 1: What Is Generative Programming
@@ -943,6 +944,23 @@ or the entire last turn (user query + assistant response):
 print(m.ctx.last_turn())
 ```

+You can also use `session.clone()` to create a copy of a given session with its context at a given point in time. This allows you to make multiple generation requests with the same objects in your context:
+```python
+m = start_session(ctx=ChatContext())
+m.instruct("Multiply 2x2.")
+
+m1 = m.clone()
+m2 = m.clone()
+
+# Need to run this code in an async event loop.
+co1 = m1.ainstruct("Multiply that by 3")
+co2 = m2.ainstruct("Multiply that by 5")
+
+print(await co1) # 12
+print(await co2) # 20
+```
+In the above example, both requests have `Multiply 2x2` and the LLM's response to that (presumably `4`) in their context. By cloning the session, the new requests both operate independently on that context to get the correct answers to 4 x 3 and 4 x 5.
+
 ## Chapter 8: Implementing Agents

 > **Definition:** An *agent* is a generative program in which an LLM determines the control flow of the program.
@@ -1317,6 +1335,59 @@ assert "web_search" in output.tool_calls
 result = output.tool_calls["web_search"].call_func()
 ```

+## Chapter 12: Asynchronicity
+Mellea supports asynchronous behavior in two ways: asynchronous functions, and asynchronous event loops in synchronous functions.
+
+### Asynchronous Functions
+`MelleaSession`s have asynchronous functions that work just like regular async functions in Python. These async session functions mirror their synchronous counterparts:
+```python
+m = start_session()
+result = await m.ainstruct("Write your instruction here!")
+```
+
+However, if you want to run multiple async functions at the same time, you need to be careful with your context. By default, `MelleaSession`s use a `SimpleContext` that has no history. This works just fine when running multiple async requests at once:
+```python
+m = start_session()
+coroutines = []
+
+for i in range(5):
+    coroutines.append(m.ainstruct(f"Write a math problem using {i}"))
+
+results = await asyncio.gather(*coroutines)
+```
+
+If you use a `ChatContext`, you will need to await each request before starting the next so that the context is properly updated in between:
+```python
+m = start_session(ctx=ChatContext())
+
+result = await m.ainstruct("Write a short fairy tale.")
+print(result)
+
+main_character = await m.ainstruct("Who is the main character of the previous fairy tale?")
+print(main_character)
+```
+
+Otherwise, your requests will use outdated contexts that don't have the messages you expect. For example:
+```python
+m = start_session(ctx=ChatContext())
+
+co1 = m.ainstruct("Write a very long math problem.") # Start first request.
+co2 = m.ainstruct("Solve the math problem.") # Start second request with an empty context.
+
+results = await asyncio.gather(co1, co2)
+for result in results:
+    print(result) # Neither request had anything in its context.
+
+print(m.ctx) # Only shows the operations from the second request.
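+
+# One fix, shown in the fairy tale example above, is to await each request before
+# issuing the next one so that the ChatContext is updated in between
+# (the variable names below are illustrative):
+#   problem = await m.ainstruct("Write a very long math problem.")
+#   solution = await m.ainstruct("Solve the math problem.")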
+``` + +Additionally, see [Chapter 7: Context Management](#chapter-7-on-context-management) for an example of how to use `session.clone()` to avoid these context issues. + +### Asynchronicity in Synchronous Functions +Mellea utilizes asynchronicity internally. When you call `m.instruct`, you are using synchronous code that executes an asynchronous request to an LLM to generate the result. For a single request, this won't cause any differences in execution speed. + +When using `SamplingStrategy`s or during validation, Mellea can speed up the execution time of your program by generating multiple results and validating those results against multiple requirements simultaneously. Whether you use `m.instruct` or the asynchronous `m.ainstruct`, Mellea will attempt to speed up your requests by dispatching those requests as quickly as possible and asynchronously awaiting the results. + ## Appendix: Contributing to Mellea ### Contributor Guide: Requirements and Verifiers diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index 47607b53..79c81d06 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -40,6 +40,7 @@ def act( context: Context, backend: Backend, *, + requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, @@ -54,6 +55,7 @@ def act( context: Context, backend: Backend, *, + requirements: list[Requirement] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, @@ -88,10 +90,10 @@ def act( tool_calls: if true, tool calling is enabled. Returns: - A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ out = _run_async_in_thread( - _act( + aact( action, context, backend, @@ -101,13 +103,352 @@ def act( format=format, model_options=model_options, tool_calls=tool_calls, + silence_context_type_warning=True, # We can safely silence this here since it's in a sync function. + ) # type: ignore[call-overload] + # Mypy doesn't like the bool for return_sampling_results. + ) + + return out + + +@overload +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context]: ... 
+ + +@overload +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> SamplingResult: ... + + +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context] | SamplingResult: + """Generates from an instruction. + + Args: + description: The description of the instruction. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + requirements: A list of requirements that the instruction can be validated against. + icl_examples: A list of in-context-learning examples that the instruction can be validated against. + grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. + user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. + prefix: A prefix string or ContentBlock to use when generating the instruction. + output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: If set, the BaseModel to use for constrained decoding. + model_options: Additional model options, which will upsert into the model/backend's defaults. + tool_calls: If true, tool calling is enabled. + images: A list of images to be used in the instruction or None if none. + + Returns: + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. + """ + requirements = [] if requirements is None else requirements + icl_examples = [] if icl_examples is None else icl_examples + grounding_context = dict() if grounding_context is None else grounding_context + + images = _parse_and_clean_image_args(images) + + # All instruction options are forwarded to create a new Instruction object. 
+ i = Instruction( + description=description, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + images=images, + ) + + return act( + i, + context=context, + backend=backend, + requirements=i.requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore[call-overload] + + +def chat( + content: str, + context: Context, + backend: Backend, + *, + role: Message.Role = "user", + images: list[ImageBlock] | list[PILImage.Image] | None = None, + user_variables: dict[str, str] | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[Message, Context]: + """Sends a simple chat message and returns the response. Adds both messages to the Context.""" + if user_variables is not None: + content_resolved = Instruction.apply_user_dict_from_jinja( + user_variables, content + ) + else: + content_resolved = content + images = _parse_and_clean_image_args(images) + user_message = Message(role=role, content=content_resolved, images=images) + + result, new_ctx = act( + user_message, + context=context, + backend=backend, + strategy=None, # Explicitly pass `None` since this can't pass requirements. + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + parsed_assistant_message = result.parsed_repr + assert isinstance(parsed_assistant_message, Message) + + return parsed_assistant_message, new_ctx + + +def validate( + reqs: Requirement | list[Requirement], + context: Context, + backend: Backend, + *, + output: CBlock | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] + | None = None, # TODO: Can we get rid of gen logs here and in act? + input: CBlock | None = None, +) -> list[ValidationResult]: + """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" + # Run everything in the specific event loop for this session. + + out = _run_async_in_thread( + avalidate( + reqs=reqs, + context=context, + backend=backend, + output=output, + format=format, + model_options=model_options, + generate_logs=generate_logs, + input=input, ) ) + # Wait for and return the result. return out -async def _act( +def query( + obj: Any, + query: str, + context: Context, + backend: Backend, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context]: + """Query method for retrieving information from an object. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + query: The string representing the query to be executed against the object. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + format: format for output parsing. + model_options: Model options to pass to the backend. + tool_calls: If true, the model may make tool calls. Defaults to False. + + Returns: + ModelOutputThunk: The result of the query as processed by the backend. 
+ """ + if not isinstance(obj, MObjectProtocol): + obj = mify(obj) + + assert isinstance(obj, MObjectProtocol) + q = obj.get_query_object(query) + + answer = act( + q, + context=context, + backend=backend, + strategy=None, # Explicitly pass `None` since this can't pass requirements. + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + return answer + + +def transform( + obj: Any, + transformation: str, + context: Context, + backend: Backend, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, +) -> tuple[ModelOutputThunk | Any, Context]: + """Transform method for creating a new object with the transformation applied. + + Args: + obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + transformation: The string representing the query to be executed against the object. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + format: format for output parsing; usually not needed with transform. + model_options: Model options to pass to the backend. + + Returns: + (ModelOutputThunk | Any, Context): The result of the transformation as processed by the backend. If no tools were called, + the return type will be always be (ModelOutputThunk, Context). If a tool was called, the return type will be the return type + of the function called, usually the type of the object passed in. + """ + if not isinstance(obj, MObjectProtocol): + obj = mify(obj) + + assert isinstance(obj, MObjectProtocol) + t = obj.get_transform_object(transformation) + + # Check that your model / backend supports tool calling. + # This might throw an error when tools are provided but can't be handled by one or the other. + transformed, new_ctx = act( + t, + context=context, + backend=backend, + strategy=None, # Explicitly pass `None` since this can't pass requirements. + format=format, + model_options=model_options, + tool_calls=True, + ) + + tools = _call_tools(transformed, backend) + + # Transform only supports calling one tool call since it cannot currently synthesize multiple outputs. + # Attempt to choose the best one to call. + chosen_tool: ToolMessage | None = None + if len(tools) == 1: + # Only one function was called. Choose that one. + chosen_tool = tools[0] + + elif len(tools) > 1: + for output in tools: + if type(output._tool_output) is type(obj): + chosen_tool = output + break + + if chosen_tool is None: + chosen_tool = tools[0] + + FancyLogger.get_logger().warning( + f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" + # type: ignore + ) + + if chosen_tool: + # Tell the user the function they should've called if no generated values were added. 
+ if len(chosen_tool._tool.args.keys()) == 0: + FancyLogger.get_logger().warning( + f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" + ) + + new_ctx.add(chosen_tool) + FancyLogger.get_logger().info( + "added a tool message from transform to the context" + ) + return chosen_tool._tool_output, new_ctx + + return transformed, new_ctx + + +@overload +async def aact( + action: Component, + context: Context, + backend: Backend, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + silence_context_type_warning: bool = False, +) -> tuple[ModelOutputThunk, Context]: ... + + +@overload +async def aact( + action: Component, + context: Context, + backend: Backend, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + silence_context_type_warning: bool = False, +) -> SamplingResult: ... + + +async def aact( action: Component, context: Context, backend: Backend, @@ -118,6 +459,7 @@ async def _act( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, + silence_context_type_warning: bool = False, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. @@ -131,10 +473,17 @@ async def _act( format: if set, the BaseModel to use for constrained decoding. model_options: additional model options, which will upsert into the model/backend's defaults. tool_calls: if true, tool calling is enabled. + silence_context_type_warning: if called directly from an asynchronous function, will log a warning if not using a SimpleContext Returns: - A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ + if not silence_context_type_warning and not isinstance(context, SimpleContext): + FancyLogger().get_logger().warning( + "Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests." + "\nSee the async section of the tutorial: https://github.com/generative-computing/mellea/blob/main/docs/tutorial.md#chapter-12-asynchronicity" + ) + sampling_result: SamplingResult | None = None generate_logs: list[GenerateLog] = [] @@ -165,7 +514,8 @@ async def _act( generate_logs.append(result._generate_log) else: - # if there is a reason to sample, use the sampling strategy. + # Always sample if a strategy is provided, even if no requirements were provided. + # Some sampling strategies don't use requirements or set them when instantiated. 
sampling_result = await strategy.sample( action, @@ -200,7 +550,7 @@ async def _act( @overload -def instruct( +async def ainstruct( description: str, context: Context, backend: Backend, @@ -221,7 +571,7 @@ def instruct( @overload -def instruct( +async def ainstruct( description: str, context: Context, backend: Backend, @@ -241,7 +591,7 @@ def instruct( ) -> SamplingResult: ... -def instruct( +async def ainstruct( description: str, context: Context, backend: Backend, @@ -277,6 +627,9 @@ def instruct( model_options: Additional model options, which will upsert into the model/backend's defaults. tool_calls: If true, tool calling is enabled. images: A list of images to be used in the instruction or None if none. + + Returns: + A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ requirements = [] if requirements is None else requirements icl_examples = [] if icl_examples is None else icl_examples @@ -296,7 +649,7 @@ def instruct( images=images, ) - return act( + return await aact( i, context=context, backend=backend, @@ -309,7 +662,7 @@ def instruct( ) # type: ignore[call-overload] -def chat( +async def achat( content: str, context: Context, backend: Backend, @@ -331,11 +684,11 @@ def chat( images = _parse_and_clean_image_args(images) user_message = Message(role=role, content=content_resolved, images=images) - result, new_ctx = act( + result, new_ctx = await aact( user_message, context=context, backend=backend, - strategy=None, + strategy=None, # Explicitly pass `None` since this can't pass requirements. format=format, model_options=model_options, tool_calls=tool_calls, @@ -346,39 +699,7 @@ def chat( return parsed_assistant_message, new_ctx -def validate( - reqs: Requirement | list[Requirement], - context: Context, - backend: Backend, - *, - output: CBlock | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - generate_logs: list[GenerateLog] - | None = None, # TODO: Can we get rid of gen logs here and in act? - input: CBlock | None = None, -) -> list[ValidationResult]: - """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - # Run everything in the specific event loop for this session. - - out = _run_async_in_thread( - _validate( - reqs=reqs, - context=context, - backend=backend, - output=output, - format=format, - model_options=model_options, - generate_logs=generate_logs, - input=input, - ) - ) - - # Wait for and return the result. - return out - - -async def _validate( +async def avalidate( reqs: Requirement | list[Requirement], context: Context, backend: Backend, @@ -435,7 +756,7 @@ async def _validate( return rvs -def query( +async def aquery( obj: Any, query: str, context: Context, @@ -465,11 +786,11 @@ def query( assert isinstance(obj, MObjectProtocol) q = obj.get_query_object(query) - answer = act( + answer = await aact( q, context=context, backend=backend, - strategy=None, + strategy=None, # Explicitly pass `None` since this can't pass requirements. format=format, model_options=model_options, tool_calls=tool_calls, @@ -477,7 +798,7 @@ def query( return answer -def transform( +async def atransform( obj: Any, transformation: str, context: Context, @@ -509,11 +830,11 @@ def transform( # Check that your model / backend supports tool calling. # This might throw an error when tools are provided but can't be handled by one or the other. 
- transformed, new_ctx = act( + transformed, new_ctx = await aact( t, context=context, backend=backend, - strategy=None, + strategy=None, # Explicitly pass `None` since this can't pass requirements. format=format, model_options=model_options, tool_calls=True, diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 1e822871..e2b2d57c 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -1,8 +1,9 @@ """A method to generate outputs based on python functions and a Generative Slot function.""" +import asyncio import functools import inspect -from collections.abc import Callable +from collections.abc import Callable, Coroutine from copy import deepcopy from typing import Any, Generic, ParamSpec, TypedDict, TypeVar, get_type_hints @@ -168,14 +169,13 @@ def __call__( **kwargs: Additional Kwargs to be passed to the func. Returns: - ModelOutputThunk: Output with generated Thunk. + R: an object with the original return type of the function """ if m is None: m = get_session() slot_copy = deepcopy(self) arguments = bind_function_arguments(self._function._func, *args, **kwargs) if arguments: - # slot_copy._arguments = [] for key, val in arguments.items(): annotation = get_annotation(slot_copy._function._func, key, val) slot_copy._arguments.append(Argument(annotation, key, val)) @@ -207,6 +207,54 @@ def format_for_llm(self) -> TemplateRepresentation: ) +class AsyncGenerativeSlot(GenerativeSlot, Generic[P, R]): + """A generative slot component that generates asynchronously and returns a coroutine.""" + + def __call__( + self, + m: MelleaSession | None = None, + model_options: dict | None = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> Coroutine[Any, Any, R]: + """Call the async generative slot. + + Args: + m: MelleaSession: A mellea session (optional, uses context if None) + model_options: Model options to pass to the backend. + *args: Additional args to be passed to the func. + **kwargs: Additional Kwargs to be passed to the func + + Returns: + Coroutine[Any, Any, R]: a coroutine that returns an object with the original return type of the function + """ + if m is None: + m = get_session() + slot_copy = deepcopy(self) + arguments = bind_function_arguments(self._function._func, *args, **kwargs) + if arguments: + for key, val in arguments.items(): + annotation = get_annotation(slot_copy._function._func, key, val) + slot_copy._arguments.append(Argument(annotation, key, val)) + + response_model = create_response_format(self._function._func) + + # AsyncGenerativeSlots are used with async functions. In order to support that behavior, + # they must return a coroutine object. + async def __async_call__() -> R: + # Use the async act func so that control flow doesn't get stuck here in async event loops. + response = await m.aact( + slot_copy, format=response_model, model_options=model_options + ) + + function_response: FunctionResponse[R] = response_model.model_validate_json( + response.value # type: ignore + ) + return function_response.result + + return __async_call__() + + def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: """Convert a function into an AI-powered function. @@ -216,6 +264,8 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: that function's behavior. The output is guaranteed to match the return type annotation using structured outputs and automatic validation. + Note: Works with async functions as well. + Tip: Write the function and docstring in the most Pythonic way possible, not like a prompt. 
This ensures the function is well-documented, easily understood, and familiar to any Python developer. The more natural and conventional your @@ -248,7 +298,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... estimated_hours: float >>> >>> @generative - ... def create_project_tasks(project_desc: str, count: int) -> List[Task]: + ... async def create_project_tasks(project_desc: str, count: int) -> List[Task]: ... '''Generate a list of realistic tasks for a project. ... ... Args: @@ -260,7 +310,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... ''' ... ... >>> - >>> tasks = create_project_tasks(session, "Build a web app", 5) + >>> tasks = await create_project_tasks(session, "Build a web app", 5) >>> @generative ... def analyze_code_quality(code: str) -> Dict[str, Any]: @@ -304,7 +354,10 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: >>> >>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?") """ - return GenerativeSlot(func) + if inspect.iscoroutinefunction(func): + return AsyncGenerativeSlot(func) + else: + return GenerativeSlot(func) # Export the decorator as the interface diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 9f374fa9..6520a7da 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -161,7 +161,7 @@ async def sample( await result.avalue() # validation pass - val_scores_co = mfuncs._validate( + val_scores_co = mfuncs.avalidate( reqs=reqs, context=result_ctx, backend=backend, diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index d0bdf341..59c402b2 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -127,7 +127,7 @@ async def sample( result = sampled_results[i] next_action = sampled_actions[i] - val_scores_co = mfuncs._validate( + val_scores_co = mfuncs.avalidate( reqs=reqs, context=result_ctx, backend=backend, diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 6aa3204b..2a63a71a 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -3,6 +3,7 @@ from __future__ import annotations import contextvars +from copy import copy from typing import Any, Literal, overload from PIL import Image as PILImage @@ -29,6 +30,7 @@ from mellea.stdlib.chat import Message from mellea.stdlib.requirement import Requirement, ValidationResult from mellea.stdlib.sampling import SamplingResult, SamplingStrategy +from mellea.stdlib.sampling.base import RejectionSamplingStrategy # Global context variable for the context session _context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar( @@ -175,11 +177,9 @@ def __init__(self, backend: Backend, ctx: Context | None = None): Args: backend (Backend): This is always required. ctx (Context): The way in which the model's context will be managed. By default, each interaction with the model is a stand-alone interaction, so we use SimpleContext as the default. - model_options (Optional[dict]): model options, which will upsert into the model/backend's defaults. 
""" self.backend = backend self.ctx: Context = ctx if ctx is not None else SimpleContext() - self._backend_stack: list[tuple[Backend, dict | None]] = [] self._session_logger = FancyLogger.get_logger() self._context_token = None @@ -195,30 +195,36 @@ def __exit__(self, exc_type, exc_val, exc_tb): _context_session.reset(self._context_token) self._context_token = None - def _push_model_state(self, new_backend: Backend, new_model_opts: dict): - """The backend and model options used within a `Context` can be temporarily changed. This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`. + def __copy__(self): + """Use self.clone. Copies the current session but keeps references to the backend and context.""" + new = MelleaSession(backend=self.backend, ctx=self.ctx) + new._session_logger = self._session_logger + # Explicitly don't copy over the _context_token. - Question: should this logic be moved into context? I really want to keep `Session` as simple as possible... see true motivation in the docstring for the class. - """ - self._backend_stack.append((self.backend, self.model_options)) - self.backend = new_backend - self.opts = new_model_opts - - def _pop_model_state(self) -> bool: - """Pops the model state. + return new - The backend and model options used within a `Context` can be temporarily changed by pushing and popping from the model state. - This function restores the model's previous backend and model_opts from the `self._backend_stack`. + def clone(self): + """Useful for running multiple generation requests while keeping the context at a given point in time. - Question: should this logic be moved into context? I really want to keep `Session` as simple as possible... see true motivation in the docstring for the class. + Returns: + a copy of the current session. Keeps the context, backend, and session logger. + + Examples: + >>> from mellea import start_session + >>> m = start_session() + >>> m.instruct("What is 2x2?") + >>> + >>> m1 = m.clone() + >>> out = m1.instruct("Multiply that by 2") + >>> print(out) + ... 8 + >>> + >>> m2 = m.clone() + >>> out = m2.instruct("Multiply that by 3") + >>> print(out) + ... 12 """ - try: - b, b_model_opts = self._backend_stack.pop() - self.backend = b - self.model_options = b_model_opts - return True - except Exception: - return False + return copy(self) def reset(self): """Reset the context state.""" @@ -227,7 +233,6 @@ def reset(self): def cleanup(self) -> None: """Clean up session resources.""" self.reset() - self._backend_stack.clear() if hasattr(self.backend, "close"): self.backend.close() # type: ignore @@ -236,7 +241,8 @@ def act( self, action: Component, *, - strategy: SamplingStrategy | None = None, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -248,14 +254,25 @@ def act( self, action: Component, *, - strategy: SamplingStrategy | None = None, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, ) -> SamplingResult: ... 
- def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: # noqa: D417 + def act( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. Args: @@ -270,11 +287,25 @@ def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: Returns: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - result, context = mfuncs.act( - action, context=self.ctx, backend=self.backend, **kwargs - ) - self.ctx = context - return result + r = mfuncs.act( + action, + context=self.ctx, + backend=self.backend, + requirements=requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore + + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + result, context = r + self.ctx = context + return result @overload def instruct( @@ -288,7 +319,7 @@ def instruct( user_variables: dict[str, str] | None = None, prefix: str | CBlock | None = None, output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[False] = False, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -307,14 +338,30 @@ def instruct( user_variables: dict[str, str] | None = None, prefix: str | CBlock | None = None, output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), return_sampling_results: Literal[True], format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, ) -> SamplingResult: ... - def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingResult: # noqa: D417 + def instruct( + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: """Generates from an instruction. Args: @@ -333,13 +380,28 @@ def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingRes images: A list of images to be used in the instruction or None if none. 
""" r = mfuncs.instruct( - description, context=self.ctx, backend=self.backend, **kwargs + description, + context=self.ctx, + backend=self.backend, + images=images, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + strategy=strategy, + return_sampling_results=return_sampling_results, # type: ignore + format=format, + model_options=model_options, + tool_calls=tool_calls, ) if isinstance(r, SamplingResult): self.ctx = r.result_ctx return r else: + # It's a tuple[ModelOutputThunk, Context]. result, context = r self.ctx = context return result @@ -458,6 +520,290 @@ def transform( self.ctx = context return result + @overload + async def aact( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: ... + + @overload + async def aact( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> SamplingResult: ... + + async def aact( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: + """Runs a generic action, and adds both the action and the result to the context. + + Args: + action: the Component from which to generate. + requirements: used as additional requirements when a sampling strategy is provided + strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: if set, the BaseModel to use for constrained decoding. + model_options: additional model options, which will upsert into the model/backend's defaults. + tool_calls: if true, tool calling is enabled. + + Returns: + A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. 
+ """ + r = await mfuncs.aact( + action, + context=self.ctx, + backend=self.backend, + requirements=requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore + + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + result, context = r + self.ctx = context + return result + + @overload + async def ainstruct( + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: ... + + @overload + async def ainstruct( + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> SamplingResult: ... + + async def ainstruct( + self, + description: str, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: + """Generates from an instruction. + + Args: + description: The description of the instruction. + requirements: A list of requirements that the instruction can be validated against. + icl_examples: A list of in-context-learning examples that the instruction can be validated against. + grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. + user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. + prefix: A prefix string or ContentBlock to use when generating the instruction. + output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. 
None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: If set, the BaseModel to use for constrained decoding. + model_options: Additional model options, which will upsert into the model/backend's defaults. + tool_calls: If true, tool calling is enabled. + images: A list of images to be used in the instruction or None if none. + """ + r = await mfuncs.ainstruct( + description, + context=self.ctx, + backend=self.backend, + images=images, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + strategy=strategy, + return_sampling_results=return_sampling_results, # type: ignore + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + # It's a tuple[ModelOutputThunk, Context]. + result, context = r + self.ctx = context + return result + + async def achat( + self, + content: str, + role: Message.Role = "user", + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + user_variables: dict[str, str] | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> Message: + """Sends a simple chat message and returns the response. Adds both messages to the Context.""" + result, context = await mfuncs.achat( + content=content, + context=self.ctx, + backend=self.backend, + role=role, + images=images, + user_variables=user_variables, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + + self.ctx = context + return result + + async def avalidate( + self, + reqs: Requirement | list[Requirement], + *, + output: CBlock | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] | None = None, + input: CBlock | None = None, + ) -> list[ValidationResult]: + """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" + return await mfuncs.avalidate( + reqs=reqs, + context=self.ctx, + backend=self.backend, + output=output, + format=format, + model_options=model_options, + generate_logs=generate_logs, + input=input, + ) + + async def aquery( + self, + obj: Any, + query: str, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk: + """Query method for retrieving information from an object. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + query: The string representing the query to be executed against the object. + format: format for output parsing. + model_options: Model options to pass to the backend. + tool_calls: If true, the model may make tool calls. Defaults to False. + + Returns: + ModelOutputThunk: The result of the query as processed by the backend. 
+ """ + result, context = await mfuncs.aquery( + obj=obj, + query=query, + context=self.ctx, + backend=self.backend, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + self.ctx = context + return result + + async def atransform( + self, + obj: Any, + transformation: str, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + ) -> ModelOutputThunk | Any: + """Transform method for creating a new object with the transformation applied. + + Args: + obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + transformation: The string representing the query to be executed against the object. + format: format for output parsing; usually not needed with transform. + model_options: Model options to pass to the backend. + + Returns: + ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, + the return type will be always be ModelOutputThunk. If a tool was called, the return type will be the return type + of the function called, usually the type of the object passed in. + """ + result, context = await mfuncs.atransform( + obj=obj, + transformation=transformation, + context=self.ctx, + backend=self.backend, + format=format, + model_options=model_options, + ) + self.ctx = context + return result + # ############################### # Convenience functions # ############################### diff --git a/pyproject.toml b/pyproject.toml index 4a02f614..bd1652bc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ dev = [ "ruff>=0.11.6", "pdm>=2.24.0", "pytest", + "pytest-asyncio", "mypy>=1.17.0", "python-semantic-release~=7.32", ] @@ -176,7 +177,7 @@ python_version = "3.10" markers = [ "qualitative: Marks the test as needing an exact output from an LLM; set by an ENV variable for CICD. All tests marked with this will xfail in CI/CD" ] - +asyncio_mode = "auto" # Don't require explicitly marking async tests. [tool.semantic_release] # for default values check: diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 4e3ab441..3f40d4cd 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -47,7 +47,7 @@ def test_system_prompt(session): print(result) @pytest.mark.qualitative -def test_constraint_alora(session, backend): +async def test_constraint_alora(session, backend): answer = session.instruct( "Corporate wants you to find the difference between these two strings: aaaaaaaaaa aaaaabaaaa. Be concise and don't write code to answer the question.", model_options={ @@ -55,18 +55,16 @@ def test_constraint_alora(session, backend): }, # Until aloras get a bit better, try not to abruptly end generation. 
) - async def alora_generate(): - alora_output = backend.get_aloras()[ - 0 - ].generate_using_strings( - input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa", - response=str(answer), - constraint="The answer mention that there is a b in the middle of one of the strings but not the other.", - force_yn=False, # make sure that the alora naturally output Y and N without constrained generation - ) - await alora_output.avalue() - assert alora_output.value in ["Y", "N"], alora_output - asyncio.run(alora_generate()) + alora_output = backend.get_aloras()[ + 0 + ].generate_using_strings( + input="Find the difference between these two strings: aaaaaaaaaa aaaaabaaaa", + response=str(answer), + constraint="The answer mention that there is a b in the middle of one of the strings but not the other.", + force_yn=False, # make sure that the alora naturally output Y and N without constrained generation + ) + await alora_output.avalue() + assert alora_output.value in ["Y", "N"], alora_output @pytest.mark.qualitative def test_constraint_lora_with_requirement(session, backend): @@ -226,42 +224,38 @@ class Answer(pydantic.BaseModel): ) @pytest.mark.qualitative -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - asyncio.run(parallel_requests()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value @pytest.mark.qualitative -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": import pytest diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index 7999e4cf..a7f4879d 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -78,42 +78,38 @@ def is_happy(text: str) -> bool: assert h is True @pytest.mark.qualitative -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - asyncio.run(parallel_requests()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value @pytest.mark.qualitative -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": import pytest diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 806747f4..1362019a 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -131,46 +131,40 @@ class Answer(pydantic.BaseModel): ) -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), - model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), - model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - - asyncio.run(parallel_requests()) - - -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - - asyncio.run(avalue()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), + model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), + model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value + + +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index 77487c6c..d2e24970 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -144,54 +144,48 @@ class Email(pydantic.BaseModel): # assert False, f"formatting directive failed for {random_result.value}: {e.json()}" -def test_async_parallel_requests(m_session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = m_session.backend.generate_from_context( - CBlock("Say Hello."), SimpleContext(), model_options=model_opts - ) - mot2, _ = m_session.backend.generate_from_context( - CBlock("Say Goodbye!"),SimpleContext(), model_options=model_opts - ) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" +async def test_async_parallel_requests(m_session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = m_session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext(), model_options=model_opts + ) + mot2, _ = m_session.backend.generate_from_context( + CBlock("Say Goodbye!"),SimpleContext(), model_options=model_opts + ) - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), ( - "final val should contain the first streamed chunk" - ) - assert m2_final_val.startswith(m2_val), ( - "final val should contain the first streamed chunk" - ) + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() - asyncio.run(parallel_requests()) + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), ( + "final val should contain the first streamed chunk" + ) + assert m2_final_val.startswith(m2_val), ( + "final val should contain the first streamed chunk" + ) + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value -def test_async_avalue(m_session): - async def avalue(): - mot1, _ = m_session.backend.generate_from_context( - CBlock("Say Hello."), SimpleContext() - ) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(m_session): + mot1, _ = m_session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext() + ) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index 12ec10d3..907d4575 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -100,42 +100,38 @@ def test_generate_from_raw(session: MelleaSession): assert len(results) == len(prompts) @pytest.mark.qualitative -def test_async_parallel_requests(session): - async def parallel_requests(): - model_opts = {ModelOption.STREAM: True} - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) - - m1_val = None - m2_val = None - if not mot1.is_computed(): - m1_val = await mot1.astream() - if not mot2.is_computed(): - m2_val = await mot2.astream() - - assert m1_val is not None, "should be a string val after generation" - assert m2_val is not None, "should be a string val after generation" - - m1_final_val = await mot1.avalue() - m2_final_val = await mot2.avalue() - - # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response - # contains the full response. - assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" - assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" - - assert m1_final_val == mot1.value - assert m2_final_val == mot2.value - asyncio.run(parallel_requests()) +async def test_async_parallel_requests(session): + model_opts = {ModelOption.STREAM: True} + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value @pytest.mark.qualitative -def test_async_avalue(session): - async def avalue(): - mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) - m1_final_val = await mot1.avalue() - assert m1_final_val is not None - assert m1_final_val == mot1.value - asyncio.run(avalue()) +async def test_async_avalue(session): + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value if __name__ == "__main__": import pytest diff --git a/test/stdlib_basics/test_funcs.py b/test/stdlib_basics/test_funcs.py index f652eb98..189fb6eb 100644 --- a/test/stdlib_basics/test_funcs.py +++ b/test/stdlib_basics/test_funcs.py @@ -3,8 +3,10 @@ import pytest from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock -from mellea.stdlib.funcs import instruct +from mellea.stdlib.base import CBlock, ModelOutputThunk +from mellea.stdlib.chat import Message +from mellea.stdlib.funcs import instruct, aact, avalidate, ainstruct +from mellea.stdlib.requirement import req from mellea.stdlib.session import start_session @@ -33,5 +35,46 @@ def test_func_context(m_session): assert initial_ctx is not ctx assert ctx._data is out +async def test_aact(m_session): + initial_ctx = m_session.ctx + backend = m_session.backend + + out, ctx = await aact( + Message(role="user", content="hello"), + initial_ctx, + backend + ) + + assert initial_ctx is not ctx + assert ctx._data is out + +async def test_ainstruct(m_session): + initial_ctx = m_session.ctx + backend = m_session.backend + + out, ctx = await ainstruct( + "Write a sentence", + initial_ctx, + backend + ) + + assert initial_ctx is not ctx + assert ctx._data is out + +async def test_avalidate(m_session): + initial_ctx = m_session.ctx + backend = m_session.backend + + val_result = await avalidate( + reqs=[req("Be formal."), req("Avoid telling jokes.")], + context=initial_ctx, + backend=backend, + output=ModelOutputThunk("Here is an output.") + ) + + assert len(val_result) == 2 + assert val_result[0] is not None + + if __name__ == "__main__": pytest.main([__file__]) \ No newline at end of file diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index ebcace55..c9695260 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -1,6 +1,8 @@ +import asyncio import pytest from typing import Literal from mellea import generative, start_session +from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot @generative @@ -10,6 +12,8 @@ def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... @generative def write_me_an_email() -> str: ... +@generative +async def async_write_short_sentence(topic: str) -> str: ... 
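+# async_write_short_sentence above is expected to become an AsyncGenerativeSlot rather than a
+# plain GenerativeSlot; test_async_gen_slot below checks this and awaits the coroutines it returns.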
@pytest.fixture(scope="function") def session(): @@ -29,6 +33,7 @@ def test_gen_slot_output(classify_sentiment_output): def test_func(session): + assert isinstance(write_me_an_email, GenerativeSlot) and not isinstance(write_me_an_email, AsyncGenerativeSlot) write_email_component = write_me_an_email(session) assert isinstance(write_email_component, str) @@ -43,5 +48,18 @@ def test_gen_slot_logs(classify_sentiment_output, session): assert isinstance(last_prompt, dict) assert set(last_prompt.keys()) == {"role", "content", "images"} +async def test_async_gen_slot(session): + assert isinstance(async_write_short_sentence, AsyncGenerativeSlot) + + r1 = async_write_short_sentence(session, topic="cats") + r2 = async_write_short_sentence(session, topic="dogs") + + r3 = await async_write_short_sentence(session, topic="fish") + results = await asyncio.gather(r1, r2) + + assert isinstance(r3, str) + assert len(results) == 2 + + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/stdlib_basics/test_requirement.py b/test/stdlib_basics/test_requirement.py index a1bef684..f569308d 100644 --- a/test/stdlib_basics/test_requirement.py +++ b/test/stdlib_basics/test_requirement.py @@ -7,15 +7,12 @@ ctx = ChatContext() ctx = ctx.add(ModelOutputThunk("test")) -def test_llmaj_validation_req_output_field(): +async def test_llmaj_validation_req_output_field(): m = start_session(ctx=ctx) req = Requirement("Must output test.") assert req._output is None - async def val(): - _ = await req.validate(m.backend,ctx=ctx) - asyncio.run(val()) - + _ = await req.validate(m.backend,ctx=ctx) assert req._output is None, "requirement's output shouldn't be updated during/after validation" def test_simple_validate_bool(): diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 67168a38..efcb51ed 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -1,10 +1,31 @@ +import asyncio import os import pytest -from mellea.stdlib.base import ModelOutputThunk +from mellea.backends.ollama import OllamaModelBackend +from mellea.backends.types import ModelOption +from mellea.stdlib.base import ChatContext, ModelOutputThunk +from mellea.stdlib.chat import Message from mellea.stdlib.session import start_session +# We edit the context type in the async tests below. Don't change the scope here. 
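+# A function-scoped fixture gives every test a fresh session, so tests that swap in a
+# ChatContext below don't leak that context into other tests.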
+@pytest.fixture(scope="function") +def m_session(gh_run): + if gh_run == 1: + m = start_session( + "ollama", + model_id="llama3.2:1b", + model_options={ModelOption.MAX_NEW_TOKENS: 5}, + ) + else: + m = start_session( + "ollama", + model_id="granite3.3:8b", + model_options={ModelOption.MAX_NEW_TOKENS: 5}, + ) + yield m + del m def test_start_session_watsonx(gh_run): if gh_run == 1: @@ -15,7 +36,6 @@ def test_start_session_watsonx(gh_run): assert isinstance(response, ModelOutputThunk) assert response.value is not None - def test_start_session_openai_with_kwargs(gh_run): if gh_run == 1: m = start_session( @@ -37,6 +57,73 @@ def test_start_session_openai_with_kwargs(gh_run): assert response.value is not None assert initial_ctx is not m.ctx +async def test_aact(m_session): + initial_ctx = m_session.ctx + out = await m_session.aact(Message(role="user", content="Hello!")) + assert m_session.ctx is not initial_ctx + assert out.value is not None + +async def test_ainstruct(m_session): + initial_ctx = m_session.ctx + out = await m_session.ainstruct("Write a sentence.") + assert m_session.ctx is not initial_ctx + assert out.value is not None + +async def test_async_await_with_chat_context(m_session): + m_session.ctx = ChatContext() + + m1 = Message(role="user", content="1") + m2 = Message(role="user", content="2") + r1 = await m_session.aact(m1, strategy=None) + r2 = await m_session.aact(m2, strategy=None) + + # This should be the order of these items in the session's context. + history = [r2, m2, r1, m1] + + ctx = m_session.ctx + for i in range(len(history)): + assert ctx.node_data is history[i] + ctx = ctx.previous_node + + # Ensure we made it back to the root. + assert ctx.is_root_node == True + +async def test_async_without_waiting_with_chat_context(m_session): + m_session.ctx = ChatContext() + + m1 = Message(role="user", content="1") + m2 = Message(role="user", content="2") + co1 = m_session.aact(m1) + co2 = m_session.aact(m2) + _, _ = await asyncio.gather(co2, co1) + + ctx = m_session.ctx + assert len(ctx.view_for_generation()) == 2 + +def test_session_copy_with_context_ops(m_session): + out = m_session.instruct("What is 2x2?") + main_ctx = m_session.ctx + + m1 = m_session.clone() + out1 = m1.instruct("Multiply by 3.") + + m2 = m_session.clone() + out2 = m2.instruct("Multiply by 4.") + + # Assert that each context is the correct one. + assert m_session.ctx is main_ctx + assert m_session.ctx is not m1.ctx + assert m_session.ctx is not m2.ctx + assert m1.ctx is not m2.ctx + + # Assert that node data is correct. + assert m_session.ctx.node_data is out + assert m1.ctx.node_data is out1 + assert m2.ctx.node_data is out2 + + # Assert that the new sessions still branch off the original one. 
+ assert m1.ctx.previous_node.previous_node is m_session.ctx + assert m2.ctx.previous_node.previous_node is m_session.ctx if __name__ == "__main__": pytest.main([__file__]) diff --git a/uv.lock b/uv.lock index dc920ad6..f4831e93 100644 --- a/uv.lock +++ b/uv.lock @@ -394,6 +394,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/b8/3fe70c75fe32afc4bb507f75563d39bc5642255d1d94f1f23604725780bf/babel-2.17.0-py3-none-any.whl", hash = "sha256:4d0b53093fdfb4b21c92b5213dba5a1b23885afa8383709427046b21c366e5f2", size = 10182537, upload-time = "2025-02-01T15:17:37.39Z" }, ] +[[package]] +name = "backports-asyncio-runner" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8e/ff/70dca7d7cb1cbc0edb2c6cc0c38b65cba36cccc491eca64cabd5fe7f8670/backports_asyncio_runner-1.2.0.tar.gz", hash = "sha256:a5aa7b2b7d8f8bfcaa2b57313f70792df84e32a2a746f585213373f900b42162", size = 69893, upload-time = "2025-07-02T02:27:15.685Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/59/76ab57e3fe74484f48a53f8e337171b4a2349e506eabe136d7e01d059086/backports_asyncio_runner-1.2.0-py3-none-any.whl", hash = "sha256:0da0a936a8aeb554eccb426dc55af3ba63bcdc69fa1a600b5bb305413a4477b5", size = 12313, upload-time = "2025-07-02T02:27:14.263Z" }, +] + [[package]] name = "backports-tarfile" version = "1.2.0" @@ -2425,6 +2434,7 @@ dev = [ { name = "pre-commit" }, { name = "pylint" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "python-semantic-release" }, { name = "ruff" }, ] @@ -2483,6 +2493,7 @@ dev = [ { name = "pre-commit", specifier = ">=4.2.0" }, { name = "pylint", specifier = ">=3.3.4" }, { name = "pytest" }, + { name = "pytest-asyncio" }, { name = "python-semantic-release", specifier = "~=7.32" }, { name = "ruff", specifier = ">=0.11.6" }, ] @@ -4098,6 +4109,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/a8/a4/20da314d277121d6534b3a980b29035dcd51e6744bd79075a6ce8fa4eb8d/pytest-8.4.2-py3-none-any.whl", hash = "sha256:872f880de3fc3a5bdc88a11b39c9710c3497a547cfa9320bc3c5e62fbf272e79", size = 365750, upload-time = "2025-09-04T14:34:20.226Z" }, ] +[[package]] +name = "pytest-asyncio" +version = "1.2.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "backports-asyncio-runner", marker = "python_full_version < '3.11'" }, + { name = "pytest" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/42/86/9e3c5f48f7b7b638b216e4b9e645f54d199d7abbbab7a64a13b4e12ba10f/pytest_asyncio-1.2.0.tar.gz", hash = "sha256:c609a64a2a8768462d0c99811ddb8bd2583c33fd33cf7f21af1c142e824ffb57", size = 50119, upload-time = "2025-09-12T07:33:53.816Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/93/2fa34714b7a4ae72f2f8dad66ba17dd9a2c793220719e736dda28b7aec27/pytest_asyncio-1.2.0-py3-none-any.whl", hash = "sha256:8e17ae5e46d8e7efe51ab6494dd2010f4ca8dae51652aa3c8d55acf50bfb2e99", size = 15095, upload-time = "2025-09-12T07:33:52.639Z" }, +] + [[package]] name = "python-bidi" version = "0.6.6"