From 9a7bcb08cf1681297ad790256d3ce87994ab32e1 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Thu, 9 Oct 2025 15:46:30 -0400 Subject: [PATCH 1/9] feat: remove auto get session from context; add context, backend option to genslots --- mellea/stdlib/genslot.py | 277 ++++++++++++++++++++++++----- test/stdlib_basics/test_genslot.py | 25 ++- 2 files changed, 258 insertions(+), 44 deletions(-) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index e2b2d57c..481c3122 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -1,16 +1,28 @@ """A method to generate outputs based on python functions and a Generative Slot function.""" -import asyncio +import abc import functools import inspect -from collections.abc import Callable, Coroutine +from collections.abc import Awaitable, Callable, Coroutine from copy import deepcopy -from typing import Any, Generic, ParamSpec, TypedDict, TypeVar, get_type_hints +from dataclasses import dataclass +from typing import ( + Any, + Generic, + ParamSpec, + TypedDict, + TypeVar, + Union, + get_type_hints, + overload, +) from pydantic import BaseModel, Field, create_model -from mellea.stdlib.base import Component, TemplateRepresentation -from mellea.stdlib.session import MelleaSession, get_session +import mellea.stdlib.funcs as mfuncs +from mellea.backends import Backend +from mellea.stdlib.base import Component, Context, TemplateRepresentation +from mellea.stdlib.session import MelleaSession P = ParamSpec("P") R = TypeVar("R") @@ -140,6 +152,21 @@ def __init__(self, func: Callable): self._function_dict: FunctionDict = describe_function(func) +@dataclass +class ExtractedKwargs: + """Used to extract the mellea args and original function args.""" + + f_kwargs: dict + session: MelleaSession | None = None + context: Context | None = None + backend: Backend | None = None + model_options: dict | None = None + + def __init__(self): + """Used to extract the mellea args and original function args.""" + self.f_kwargs = {} 
+ + class GenerativeSlot(Component, Generic[P, R]): """A generative slot component.""" @@ -153,17 +180,14 @@ def __init__(self, func: Callable[P, R]): self._arguments: list[Argument] = [] functools.update_wrapper(self, func) - def __call__( - self, - m: MelleaSession | None = None, - model_options: dict | None = None, - *args: P.args, - **kwargs: P.kwargs, - ) -> R: + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: """Call the generative slot. Args: - m: MelleaSession: A mellea session (optional, uses context if None) + m: MelleaSession: A mellea session (optional; must set context and backend if None) + context: the Context object (optional; session must be set if None) + backend: the backend used for generation (optional: session must be set if None) model_options: Model options to pass to the backend. *args: Additional args to be passed to the func. **kwargs: Additional Kwargs to be passed to the func. @@ -171,24 +195,74 @@ def __call__( Returns: R: an object with the original return type of the function """ - if m is None: - m = get_session() - slot_copy = deepcopy(self) - arguments = bind_function_arguments(self._function._func, *args, **kwargs) - if arguments: - for key, val in arguments.items(): - annotation = get_annotation(slot_copy._function._func, key, val) - slot_copy._arguments.append(Argument(annotation, key, val)) - - response_model = create_response_format(self._function._func) + ... - response = m.act(slot_copy, format=response_model, model_options=model_options) + @staticmethod + def extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: + """Takes a mix of args and kwargs for both the generative slot and the original function and extracts them. - function_response: FunctionResponse[R] = response_model.model_validate_json( - response.value # type: ignore - ) + Returns: + ExtractedKwargs: a dataclass of the required args for mellea and the original function. 
+ Either session or (backend, context) will be non-None. + """ + # Possible args for the generative slot. + extracted = ExtractedKwargs() + + # Args can only have mellea args. + for arg in args: + match arg: + case MelleaSession(): + extracted.session = arg + case Context(): + extracted.context = arg + case Backend(): + extracted.backend = arg + case dict(): + extracted.model_options = arg + + T = TypeVar("T") + + def _get_val_or_err(name: str, var: T | None, new_val: T) -> T: + """Returns the new_value if the original var is None, else raises a ValueError.""" + if var is None: + return new_val + else: + raise ValueError( + f"passed in multiple values of {name} to generative slot: {var}, {new_val}" + ) + + # Kwargs: + # - some / all of the mellea args + # - all of the function args (P.args and P.kwargs) + for key, val in kwargs.items(): + match key: + case "m": + extracted.session = _get_val_or_err("m", extracted.session, val) + case "context": + extracted.context = _get_val_or_err( + "context", extracted.context, val + ) + case "backend": + extracted.backend = _get_val_or_err( + "backend", extracted.backend, val + ) + case "model_options": + extracted.model_options = _get_val_or_err( + "model_options", extracted.model_options, val + ) + case _: + extracted.f_kwargs[key] = val + + # Need to check that either session is set or both backend and context are set; + # model_options can be None. 
+ if extracted.session is None and ( + extracted.backend is None or extracted.context is None + ): + raise ValueError( + f"need to pass in a session or a (backend and context) to generative slot; got session({extracted.session}), backend({extracted.backend}), context({extracted.context})" + ) - return function_response.result + return extracted def parts(self): """Not implemented.""" @@ -207,31 +281,122 @@ def format_for_llm(self) -> TemplateRepresentation: ) +class SyncGenerativeSlot(GenerativeSlot, Generic[P, R]): + @overload + def __call__( + self, + context: Context, + backend: Backend, + model_options: dict | None = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> tuple[R, Context]: ... + + @overload + def __call__( + self, + m: MelleaSession, + model_options: dict | None = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> R: ... + + def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: + """Call the generative slot. + + Args: + m: MelleaSession: A mellea session (optional; must set context and backend if None) + context: the Context object (optional; session must be set if None) + backend: the backend used for generation (optional: session must be set if None) + model_options: Model options to pass to the backend. + *args: Additional args to be passed to the func. + **kwargs: Additional Kwargs to be passed to the func. 
+ + Returns: + R: an object with the original return type of the function + """ + extracted = self.extract_args_and_kwargs(*args, **kwargs) + + slot_copy = deepcopy(self) + arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) + if arguments: + for key, val in arguments.items(): + annotation = get_annotation(slot_copy._function._func, key, val) + slot_copy._arguments.append(Argument(annotation, key, val)) + + response_model = create_response_format(self._function._func) + + response, context = None, None + if extracted.session is not None: + response = extracted.session.act( + slot_copy, + strategy=None, + format=response_model, + model_options=extracted.model_options, + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. + assert extracted.context is not None + assert extracted.backend is not None + response, context = mfuncs.act( + slot_copy, + extracted.context, + extracted.backend, + strategy=None, + format=response_model, + model_options=extracted.model_options, + ) + + function_response: FunctionResponse[R] = response_model.model_validate_json( + response.value # type: ignore + ) + + if context is None: + return function_response.result + else: + return function_response.result, context + + class AsyncGenerativeSlot(GenerativeSlot, Generic[P, R]): """A generative slot component that generates asynchronously and returns a coroutine.""" + @overload + def __call__( + self, + context: Context, + backend: Backend, + model_options: dict | None = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> Coroutine[Any, Any, tuple[R, Context]]: ... + + @overload def __call__( self, - m: MelleaSession | None = None, + m: MelleaSession, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, - ) -> Coroutine[Any, Any, R]: + ) -> Coroutine[Any, Any, R]: ... + + def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R]: """Call the async generative slot. 
Args: - m: MelleaSession: A mellea session (optional, uses context if None) + m: MelleaSession: A mellea session (optional; must set context and backend if None) + context: the Context object (optional; session must be set if None) + backend: the backend used for generation (optional: session must be set if None) model_options: Model options to pass to the backend. *args: Additional args to be passed to the func. - **kwargs: Additional Kwargs to be passed to the func + **kwargs: Additional Kwargs to be passed to the func. Returns: Coroutine[Any, Any, R]: a coroutine that returns an object with the original return type of the function """ - if m is None: - m = get_session() + extracted = self.extract_args_and_kwargs(*args, **kwargs) + slot_copy = deepcopy(self) - arguments = bind_function_arguments(self._function._func, *args, **kwargs) + arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) if arguments: for key, val in arguments.items(): annotation = get_annotation(slot_copy._function._func, key, val) @@ -241,20 +406,50 @@ def __call__( # AsyncGenerativeSlots are used with async functions. In order to support that behavior, # they must return a coroutine object. - async def __async_call__() -> R: + async def __async_call__() -> tuple[R, Context] | R: + response, context = None, None + # Use the async act func so that control flow doesn't get stuck here in async event loops. - response = await m.aact( - slot_copy, format=response_model, model_options=model_options - ) + if extracted.session is not None: + response = await extracted.session.aact( + slot_copy, + strategy=None, + format=response_model, + model_options=extracted.model_options, + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. 
+ assert extracted.context is not None + assert extracted.backend is not None + response, context = await mfuncs.aact( + slot_copy, + extracted.context, + extracted.backend, + strategy=None, + format=response_model, + model_options=extracted.model_options, + ) function_response: FunctionResponse[R] = response_model.model_validate_json( response.value # type: ignore ) - return function_response.result + + if context is None: + return function_response.result + else: + return function_response.result, context return __async_call__() +@overload +def generative(func: Callable[P, Awaitable[R]]) -> AsyncGenerativeSlot[P, R]: ... # type: ignore + + +@overload +def generative(func: Callable[P, R]) -> SyncGenerativeSlot[P, R]: ... + + def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: """Convert a function into an AI-powered function. @@ -357,7 +552,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: if inspect.iscoroutinefunction(func): return AsyncGenerativeSlot(func) else: - return GenerativeSlot(func) + return SyncGenerativeSlot(func) # Export the decorator as the interface diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index c9695260..b015f7f4 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -2,7 +2,8 @@ import pytest from typing import Literal from mellea import generative, start_session -from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot +from mellea.stdlib.base import Context +from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, SyncGenerativeSlot @generative @@ -33,7 +34,7 @@ def test_gen_slot_output(classify_sentiment_output): def test_func(session): - assert isinstance(write_me_an_email, GenerativeSlot) and not isinstance(write_me_an_email, AsyncGenerativeSlot) + assert isinstance(write_me_an_email, SyncGenerativeSlot) write_email_component = write_me_an_email(session) assert isinstance(write_email_component, str) @@ 
-48,18 +49,36 @@ def test_gen_slot_logs(classify_sentiment_output, session): assert isinstance(last_prompt, dict) assert set(last_prompt.keys()) == {"role", "content", "images"} +def test_gen_slot_with_context_and_backend(session): + email, context = write_me_an_email(context=session.ctx, backend=session.backend) + assert isinstance(email, str) + assert isinstance(context, Context) + async def test_async_gen_slot(session): assert isinstance(async_write_short_sentence, AsyncGenerativeSlot) r1 = async_write_short_sentence(session, topic="cats") r2 = async_write_short_sentence(session, topic="dogs") - r3 = await async_write_short_sentence(session, topic="fish") + r3, c3 = await async_write_short_sentence(context=session.ctx, backend=session.backend, topic="fish") results = await asyncio.gather(r1, r2) assert isinstance(r3, str) + assert isinstance(c3, Context) assert len(results) == 2 +def test_duplicate_args(session): + with pytest.raises(ValueError, match="passed in multiple values"): + _ = write_me_an_email(session.ctx, backend=session.backend, context=session.ctx) # type: ignore + +def test_extra_args(session): + with pytest.raises(TypeError, match="got an unexpected keyword argument"): + _ = write_me_an_email(m=session, random_param="random_param") # type: ignore + +def test_without_required_args(): + with pytest.raises(ValueError, match="need to pass in a session or a"): + _ = write_me_an_email() # type: ignore + if __name__ == "__main__": pytest.main([__file__]) From d9be154c49b65fda17727c239ba7ba5d87dfa848 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Fri, 10 Oct 2025 16:26:00 -0400 Subject: [PATCH 2/9] test: add gen slot changes that don't work --- mellea/stdlib/genslot.py | 34 +++++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 481c3122..ce853e01 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -22,6 +22,9 @@ import 
mellea.stdlib.funcs as mfuncs from mellea.backends import Backend from mellea.stdlib.base import Component, Context, TemplateRepresentation +from mellea.stdlib.requirement import Requirement, reqify +from mellea.stdlib.sampling.base import RejectionSamplingStrategy +from mellea.stdlib.sampling.types import SamplingStrategy from mellea.stdlib.session import MelleaSession P = ParamSpec("P") @@ -330,7 +333,8 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: if extracted.session is not None: response = extracted.session.act( slot_copy, - strategy=None, + requirements=self._requirements, + strategy=self._strategy, format=response_model, model_options=extracted.model_options, ) @@ -425,7 +429,8 @@ async def __async_call__() -> tuple[R, Context] | R: slot_copy, extracted.context, extracted.backend, - strategy=None, + requirements=self._requirements, + strategy=self._strategy, format=response_model, model_options=extracted.model_options, ) @@ -445,10 +450,14 @@ async def __async_call__() -> tuple[R, Context] | R: @overload def generative(func: Callable[P, Awaitable[R]]) -> AsyncGenerativeSlot[P, R]: ... # type: ignore - @overload def generative(func: Callable[P, R]) -> SyncGenerativeSlot[P, R]: ... +# @overload +# # def generative(*, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2) +# # ) -> Callable[[Callable[P, R]], GenerativeSlot[P, R]]:... +# def generative(*, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2) +# ) -> Callable[[Callable[P, R]], SyncGenerativeSlot[P, R]] | Callable[[Callable[P, Awaitable[R]]], AsyncGenerativeSlot[P, R]]:... def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: """Convert a function into an AI-powered function. 
@@ -549,10 +558,21 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: >>> >>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?") """ - if inspect.iscoroutinefunction(func): - return AsyncGenerativeSlot(func) - else: - return SyncGenerativeSlot(func) + # Grab and remove the func if it exists in kwargs. Otherwise, it's the only arg. + def _generative(func) -> GenerativeSlot[P, R]: + if inspect.iscoroutinefunction(func): + return AsyncGenerativeSlot(func) + else: + return SyncGenerativeSlot(func) + + if func is not None: + # If there is a function passed in, we can apply the decorator immediately. + return _generative(func) + + # If no function is passed in, the decorator is being called as @generative(...); + # need to return the _generative function with parameters in its closure. This will then + # be used as the decorator. + return _generative # Export the decorator as the interface From 577622502955b8d0a141b6ece67cec4a9185007b Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Mon, 20 Oct 2025 08:30:59 -0400 Subject: [PATCH 3/9] test: gen slot changes --- mellea/stdlib/genslot.py | 81 ++++++++++++++++++++++++++-------------- 1 file changed, 52 insertions(+), 29 deletions(-) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index ce853e01..5747c659 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -164,6 +164,8 @@ class ExtractedKwargs: context: Context | None = None backend: Backend | None = None model_options: dict | None = None + requirements: list[Requirement | str] | None = None + strategy: SamplingStrategy | None = None def __init__(self): """Used to extract the mellea args and original function args.""" @@ -207,11 +209,15 @@ def extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: Returns: ExtractedKwargs: a dataclass of the required args for mellea and the original function. Either session or (backend, context) will be non-None. 
+ + Raises: + TODO: JAL """ # Possible args for the generative slot. extracted = ExtractedKwargs() - # Args can only have mellea args. + # Args can only have Mellea args. If the Mellea args get more complicated / + # have duplicate types, use list indices rather than a match statement. for arg in args: match arg: case MelleaSession(): @@ -222,6 +228,12 @@ def extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: extracted.backend = arg case dict(): extracted.model_options = arg + case SamplingStrategy(): + extracted.strategy = arg + case list(): + extracted.requirements = arg + + # TODO: JAL; make sure model opts doesn't conflict with f_kwargs here... T = TypeVar("T") @@ -234,9 +246,9 @@ def _get_val_or_err(name: str, var: T | None, new_val: T) -> T: f"passed in multiple values of {name} to generative slot: {var}, {new_val}" ) - # Kwargs: - # - some / all of the mellea args - # - all of the function args (P.args and P.kwargs) + # Kwargs can contain + # - some / all of the Mellea args + # - all of the function args (P.kwargs); the syntax prevents passing a P.arg to the genslot for key, val in kwargs.items(): match key: case "m": @@ -253,6 +265,14 @@ def _get_val_or_err(name: str, var: T | None, new_val: T) -> T: extracted.model_options = _get_val_or_err( "model_options", extracted.model_options, val ) + case "strategy": + extracted.strategy = _get_val_or_err( + "strategy", extracted.strategy, val + ) + case "requirements": + extracted.requirements = _get_val_or_err( + "requirements", extracted.requirements, val + ) case _: extracted.f_kwargs[key] = val @@ -290,6 +310,8 @@ def __call__( self, context: Context, backend: Backend, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, @@ -299,6 +321,8 @@ def __call__( def __call__( self, m: MelleaSession, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = 
None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, @@ -311,6 +335,8 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: m: MelleaSession: A mellea session (optional; must set context and backend if None) context: the Context object (optional; session must be set if None) backend: the backend used for generation (optional: session must be set if None) + requirements: A list of requirements that the genslot output can be validated against. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. model_options: Model options to pass to the backend. *args: Additional args to be passed to the func. **kwargs: Additional Kwargs to be passed to the func. @@ -333,8 +359,8 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: if extracted.session is not None: response = extracted.session.act( slot_copy, - requirements=self._requirements, - strategy=self._strategy, + requirements=extracted.requirements, + strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, ) @@ -346,7 +372,8 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: slot_copy, extracted.context, extracted.backend, - strategy=None, + requirements=extracted.requirements, + strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, ) @@ -369,6 +396,8 @@ def __call__( self, context: Context, backend: Backend, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, @@ -378,6 +407,8 @@ def __call__( def __call__( self, m: MelleaSession, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, @@ -390,6 +421,8 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, 
tuple[R, Context] | R m: MelleaSession: A mellea session (optional; must set context and backend if None) context: the Context object (optional; session must be set if None) backend: the backend used for generation (optional: session must be set if None) + requirements: A list of requirements that the genslot output can be validated against. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. model_options: Model options to pass to the backend. *args: Additional args to be passed to the func. **kwargs: Additional Kwargs to be passed to the func. @@ -400,6 +433,8 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R extracted = self.extract_args_and_kwargs(*args, **kwargs) slot_copy = deepcopy(self) + # TODO: JAL; need to figure out where / how reqs work; if we want to keep as a part of the object, + # apply them here after the copy has happened... arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) if arguments: for key, val in arguments.items(): @@ -417,7 +452,8 @@ async def __async_call__() -> tuple[R, Context] | R: if extracted.session is not None: response = await extracted.session.aact( slot_copy, - strategy=None, + requirements=extracted.requirements, + strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, ) @@ -429,8 +465,8 @@ async def __async_call__() -> tuple[R, Context] | R: slot_copy, extracted.context, extracted.backend, - requirements=self._requirements, - strategy=self._strategy, + requirements=extracted.requirements, + strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, ) @@ -450,15 +486,12 @@ async def __async_call__() -> tuple[R, Context] | R: @overload def generative(func: Callable[P, Awaitable[R]]) -> AsyncGenerativeSlot[P, R]: ... 
# type: ignore + @overload def generative(func: Callable[P, R]) -> SyncGenerativeSlot[P, R]: ... -# @overload -# # def generative(*, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2) -# # ) -> Callable[[Callable[P, R]], GenerativeSlot[P, R]]:... -# def generative(*, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2) -# ) -> Callable[[Callable[P, R]], SyncGenerativeSlot[P, R]] | Callable[[Callable[P, Awaitable[R]]], AsyncGenerativeSlot[P, R]]:... +# TODO: JAL Investigate changing genslots to functions and see if it fixes the defaults being populated. def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: """Convert a function into an AI-powered function. @@ -559,20 +592,10 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: >>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?") """ # Grab and remove the func if it exists in kwargs. Otherwise, it's the only arg. - def _generative(func) -> GenerativeSlot[P, R]: - if inspect.iscoroutinefunction(func): - return AsyncGenerativeSlot(func) - else: - return SyncGenerativeSlot(func) - - if func is not None: - # If there is a function passed in, we can apply the decorator immediately. - return _generative(func) - - # If no function is passed in, the decorator is being called as @generative(...); - # need to return the _generative function with parameters in its closure. This will then - # be used as the decorator. 
- return _generative + if inspect.iscoroutinefunction(func): + return AsyncGenerativeSlot(func) + else: + return SyncGenerativeSlot(func) # Export the decorator as the interface From 6d471d8af5d3f3e6612c894e1f76a60eb18f7145 Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Mon, 20 Oct 2025 10:44:04 -0400 Subject: [PATCH 4/9] test: notes --- mellea/stdlib/genslot.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 5747c659..2e50bf5d 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -435,6 +435,9 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R slot_copy = deepcopy(self) # TODO: JAL; need to figure out where / how reqs work; if we want to keep as a part of the object, # apply them here after the copy has happened... + # need to change the template; add to docstring using postconditions: + # Postconditions: + # - The input 'data' list will be sorted in ascending order. arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) if arguments: for key, val in arguments.items(): From 04161dec2cf3d26a16a82d66fb28b0dee66b856c Mon Sep 17 00:00:00 2001 From: jakelorocco Date: Mon, 20 Oct 2025 13:50:45 -0400 Subject: [PATCH 5/9] todo: continue work on arg extraction; search for todos; add tests --- mellea/stdlib/genslot.py | 80 +++++++++++++++++++++++++++++++++++++++- 1 file changed, 79 insertions(+), 1 deletion(-) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 2e50bf5d..018d233c 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -164,6 +164,7 @@ class ExtractedKwargs: context: Context | None = None backend: Backend | None = None model_options: dict | None = None + precondition_requirements: list[Requirement | str] | None = None requirements: list[Requirement | str] | None = None strategy: SamplingStrategy | None = None @@ -185,6 +186,10 @@ def __init__(self, func: Callable[P, R]): 
self._arguments: list[Argument] = [] functools.update_wrapper(self, func) + # Set when calling the decorated func. + self.precondition_requiremetns: list[Requirement] = [] + self.requirements: list[Requirement] = [] + @abc.abstractmethod def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: """Call the generative slot. @@ -202,6 +207,37 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: """ ... + @staticmethod + def test_extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: + def _session_extract_args_and_kwargs( + session: MelleaSession | None = None, + context: Context | None = None, + backend: Backend | None = None, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, + model_options: dict | None = None, + **kwargs, + ): + extracted = ExtractedKwargs() + extracted.session = session + extracted.context = context + extracted.backend = backend + extracted.precondition_requirements = precondition_requirements + extracted.requirements = requirements + extracted.strategy = strategy + extracted.model_options = model_options + extracted.f_kwargs = kwargs + return extracted + + # Determine which overload was used if args are passed in. + if len(args) > 0: + # TODO: JAL. go back to two funcs to do this. Makes arg parsing easier here? or only check if there's a session, else everything else is good? + # need to make first arg of the array session if there's args and no session kwarg?... + # need two funcs, but otherwise all the other args will be offset incorrectly? maybe we can handle that here as well. + if isinstance(args[0], MelleaSession): + ... + @staticmethod def extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: """Takes a mix of args and kwargs for both the generative slot and the original function and extracts them. 
@@ -213,12 +249,37 @@ def extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: Raises: TODO: JAL """ + # TODO: JAL write tests for this function... # Possible args for the generative slot. extracted = ExtractedKwargs() # Args can only have Mellea args. If the Mellea args get more complicated / # have duplicate types, use list indices rather than a match statement. - for arg in args: + num_args = len(args) + cur_arg = 0 + + # First and second args will differ depending on the overload. Check for session vs (backend, context). + # context: Context, + # backend: Backend, + # precondition_requirements: list[Requirement | str] | None = None, + # requirements: list[Requirement | str] | None = None, + # strategy: SamplingStrategy | None = None, + # model_options: dict | None = None, + if num_args > 0: + if isinstance(args[cur_arg], MelleaSession): + extracted.session = args[cur_arg] + elif isinstance(args[cur_arg], Context): + extracted.session = args[cur_arg] + + if num_args > 1: + if ...: + ... 
+ else: + raise ValueError( + f"incorrect arg passed to generative slot function: {args[cur_arg]}" + ) + + for i, arg in enumerate(args): match arg: case MelleaSession(): extracted.session = arg @@ -273,6 +334,12 @@ def _get_val_or_err(name: str, var: T | None, new_val: T) -> T: extracted.requirements = _get_val_or_err( "requirements", extracted.requirements, val ) + case "precondition_requirements": + extracted.precondition_requirements = _get_val_or_err( + "precondition_requirements", + extracted.precondition_requirements, + val, + ) case _: extracted.f_kwargs[key] = val @@ -310,6 +377,7 @@ def __call__( self, context: Context, backend: Backend, + precondition_requirements: list[Requirement | str] | None = None, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = None, model_options: dict | None = None, @@ -321,6 +389,7 @@ def __call__( def __call__( self, m: MelleaSession, + precondition_requirements: list[Requirement | str] | None = None, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = None, model_options: dict | None = None, @@ -335,6 +404,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: m: MelleaSession: A mellea session (optional; must set context and backend if None) context: the Context object (optional; session must be set if None) backend: the backend used for generation (optional: session must be set if None) + precondition_requirements: A list of requirements that the genslot inputs are validated against; raises an err if not met. requirements: A list of requirements that the genslot output can be validated against. strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. model_options: Model options to pass to the backend. 
@@ -396,6 +466,7 @@ def __call__( self, context: Context, backend: Backend, + precondition_requirements: list[Requirement | str] | None = None, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = None, model_options: dict | None = None, @@ -407,6 +478,7 @@ def __call__( def __call__( self, m: MelleaSession, + precondition_requirements: list[Requirement | str] | None = None, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = None, model_options: dict | None = None, @@ -421,6 +493,7 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R m: MelleaSession: A mellea session (optional; must set context and backend if None) context: the Context object (optional; session must be set if None) backend: the backend used for generation (optional: session must be set if None) + precondition_requirements: A list of requirements that the genslot inputs are validated against; raises an err if not met. requirements: A list of requirements that the genslot output can be validated against. strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. model_options: Model options to pass to the backend. @@ -429,10 +502,15 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R Returns: Coroutine[Any, Any, R]: a coroutine that returns an object with the original return type of the function + + Raises: + # TODO: JAL. 
Change here and for all other defs; should be precondition requirements err and pydantic model validation fails """ extracted = self.extract_args_and_kwargs(*args, **kwargs) slot_copy = deepcopy(self) + slot_copy.requirements = extracted.requirements + slot_copy.precondition_requiremetns = extracted.precondition_requirements # TODO: JAL; need to figure out where / how reqs work; if we want to keep as a part of the object, # apply them here after the copy has happened... # need to change the template; add to docstring using postconditions: From c92375ee6710883f463f2a3e3c45b49fff2b23f4 Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Wed, 5 Nov 2025 11:27:44 -0500 Subject: [PATCH 6/9] feat: Change arg/kwarg extraction for gen slots; prevent passing in decorated func params as positional args --- mellea/stdlib/genslot.py | 269 ++++++++++++----------------- test/stdlib_basics/test_genslot.py | 110 ++++++++++-- 2 files changed, 210 insertions(+), 169 deletions(-) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 4a1239e2..d3dc5dec 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -5,7 +5,7 @@ import inspect from collections.abc import Awaitable, Callable, Coroutine from copy import deepcopy -from dataclasses import dataclass +from dataclasses import dataclass, fields from typing import ( Any, Generic, @@ -156,23 +156,40 @@ def __init__(self, func: Callable): @dataclass -class ExtractedKwargs: - """Used to extract the mellea args and original function args.""" +class ExtractedArgs: + """Used to extract the mellea args and original function args. See @generative decorator for additional notes on these fields. - f_kwargs: dict - session: MelleaSession | None = None + These args must match those allowed by any overload of GenerativeSlot.__call__. + """ + + f_args: tuple[Any, ...] 
+ """*args from the original function, used to detect incorrectly passed args to generative slots""" + + f_kwargs: dict[str, Any] + """**kwargs from the original function""" + + m: MelleaSession | None = None context: Context | None = None backend: Backend | None = None model_options: dict | None = None + strategy: SamplingStrategy | None = None + precondition_requirements: list[Requirement | str] | None = None + """requirements used to check the input""" + requirements: list[Requirement | str] | None = None - strategy: SamplingStrategy | None = None + """requirements used to check the output""" def __init__(self): """Used to extract the mellea args and original function args.""" + self.f_args = tuple() self.f_kwargs = {} +_disallowed_param_names = [field.name for field in fields(ExtractedArgs())] +"""A list of parameter names used by Mellea. Cannot use these in functions decorated with @generative.""" + + class GenerativeSlot(Component, Generic[P, R]): """A generative slot component.""" @@ -181,7 +198,21 @@ def __init__(self, func: Callable[P, R]): Args: func: A callable function + + Raises: + ValueError: if the decorated function has a parameter name used by generative slots """ + sig = inspect.signature(func) + problematic_param_names: list[str] = [] + for param in sig.parameters.keys(): + if param in _disallowed_param_names: + problematic_param_names.append(param) + + if len(problematic_param_names): + raise ValueError( + f"cannot create a generative slot with disallowed parameter names: {problematic_param_names}" + ) + self._function = Function(func) self._arguments: list[Argument] = [] functools.update_wrapper(self, func) @@ -192,164 +223,83 @@ def __init__(self, func: Callable[P, R]): @abc.abstractmethod def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: - """Call the generative slot. + """Call the generative slot. See subclasses for more information.""" + ... 
- Args: - m: MelleaSession: A mellea session (optional; must set context and backend if None) - context: the Context object (optional; session must be set if None) - backend: the backend used for generation (optional: session must be set if None) - model_options: Model options to pass to the backend. - *args: Additional args to be passed to the func. - **kwargs: Additional Kwargs to be passed to the func. + @staticmethod + def extract_args_and_kwargs(*args, **kwargs) -> ExtractedArgs: + """Takes a mix of args and kwargs for both the generative slot and the original function and extracts them. Ensures the original function's args are all kwargs. Returns: - R: an object with the original return type of the function + ExtractedArgs: a dataclass of the required args for mellea and the original function. + Either session or (backend, context) will be non-None. + + Raises: + TypeError: if any of the original function's parameters were passed as positional args """ - ... - @staticmethod - def test_extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: def _session_extract_args_and_kwargs( - session: MelleaSession | None = None, - context: Context | None = None, - backend: Backend | None = None, + m: MelleaSession, precondition_requirements: list[Requirement | str] | None = None, requirements: list[Requirement | str] | None = None, strategy: SamplingStrategy | None = None, model_options: dict | None = None, + *args, **kwargs, ): - extracted = ExtractedKwargs() - extracted.session = session + """Helper function for extracting args. 
Used when a session is passed.""" + extracted = ExtractedArgs() + extracted.m = m + extracted.precondition_requirements = precondition_requirements + extracted.requirements = requirements + extracted.strategy = strategy + extracted.model_options = model_options + extracted.f_args = args + extracted.f_kwargs = kwargs + return extracted + + def _context_backend_extract_args_and_kwargs( + context: Context, + backend: Backend, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, + model_options: dict | None = None, + *args, + **kwargs, + ): + """Helper function for extracting args. Used when a context and a backend are passed.""" + extracted = ExtractedArgs() extracted.context = context extracted.backend = backend extracted.precondition_requirements = precondition_requirements extracted.requirements = requirements extracted.strategy = strategy extracted.model_options = model_options + extracted.f_args = args extracted.f_kwargs = kwargs return extracted - # Determine which overload was used if args are passed in. + # Determine which overload was used: + # - if there's args, the first arg must either be a `MelleaSession` or a `Context` + # - otherwise, just check the kwargs for a "m" that is type `MelleaSession` + using_session_overload = False if len(args) > 0: - # TODO: JAL. go back to two funcs to do this. Makes arg parsing easier here? or only check if there's a session, else everything else is good? - # need to make first arg of the array session if there's args and no session kwarg?... - # need two funcs, but otherwise all the other args will be offset incorrectly? maybe we can handle that here as well. - if isinstance(args[0], MelleaSession): - ... - - @staticmethod - def extract_args_and_kwargs(*args, **kwargs) -> ExtractedKwargs: - """Takes a mix of args and kwargs for both the generative slot and the original function and extracts them. 
- - Returns: - ExtractedKwargs: a dataclass of the required args for mellea and the original function. - Either session or (backend, context) will be non-None. - - Raises: - TODO: JAL - """ - # TODO: JAL write tests for this function... - # Possible args for the generative slot. - extracted = ExtractedKwargs() - - # Args can only have Mellea args. If the Mellea args get more complicated / - # have duplicate types, use list indices rather than a match statement. - num_args = len(args) - cur_arg = 0 - - # First and second args will differ depending on the overload. Check for session vs (backend, context). - # context: Context, - # backend: Backend, - # precondition_requirements: list[Requirement | str] | None = None, - # requirements: list[Requirement | str] | None = None, - # strategy: SamplingStrategy | None = None, - # model_options: dict | None = None, - if num_args > 0: - if isinstance(args[cur_arg], MelleaSession): - extracted.session = args[cur_arg] - elif isinstance(args[cur_arg], Context): - extracted.session = args[cur_arg] - - if num_args > 1: - if ...: - ... - else: - raise ValueError( - f"incorrect arg passed to generative slot function: {args[cur_arg]}" - ) + possible_session = args[0] + else: + possible_session = kwargs.get("m", None) + if isinstance(possible_session, MelleaSession): + using_session_overload = True - for i, arg in enumerate(args): - match arg: - case MelleaSession(): - extracted.session = arg - case Context(): - extracted.context = arg - case Backend(): - extracted.backend = arg - case dict(): - extracted.model_options = arg - case SamplingStrategy(): - extracted.strategy = arg - case list(): - extracted.requirements = arg - - # TODO: JAL; make sure model opts doesn't conflict with f_kwargs here... 
- - T = TypeVar("T") - - def _get_val_or_err(name: str, var: T | None, new_val: T) -> T: - """Returns the new_value if the original var is None, else raises a ValueError.""" - if var is None: - return new_val - else: - raise ValueError( - f"passed in multiple values of {name} to generative slot: {var}, {new_val}" - ) + # Call the appropriate function and let python handle the arg/kwarg extraction. + if using_session_overload: + extracted = _session_extract_args_and_kwargs(*args, **kwargs) + else: + extracted = _context_backend_extract_args_and_kwargs(*args, **kwargs) - # Kwargs can contain - # - some / all of the Mellea args - # - all of the function args (P.kwargs); the syntax prevents passing a P.arg to the genslot - for key, val in kwargs.items(): - match key: - case "m": - extracted.session = _get_val_or_err("m", extracted.session, val) - case "context": - extracted.context = _get_val_or_err( - "context", extracted.context, val - ) - case "backend": - extracted.backend = _get_val_or_err( - "backend", extracted.backend, val - ) - case "model_options": - extracted.model_options = _get_val_or_err( - "model_options", extracted.model_options, val - ) - case "strategy": - extracted.strategy = _get_val_or_err( - "strategy", extracted.strategy, val - ) - case "requirements": - extracted.requirements = _get_val_or_err( - "requirements", extracted.requirements, val - ) - case "precondition_requirements": - extracted.precondition_requirements = _get_val_or_err( - "precondition_requirements", - extracted.precondition_requirements, - val, - ) - case _: - extracted.f_kwargs[key] = val - - # Need to check that either session is set or both backend and context are set; - # model_options can be None. 
- if extracted.session is None and ( - extracted.backend is None or extracted.context is None - ): - raise ValueError( - f"need to pass in a session or a (backend and context) to generative slot; got session({extracted.session}), backend({extracted.backend}), context({extracted.context})" + if len(extracted.f_args) > 0: + raise TypeError( + "generative slots do not accept positional args from the decorated function; use keyword args instead" ) return extracted @@ -401,8 +351,8 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: """Call the generative slot. Args: - m: MelleaSession: A mellea session (optional; must set context and backend if None) - context: the Context object (optional; session must be set if None) + m: MelleaSession: A mellea session (optional: must set context and backend if None) + context: the Context object (optional: session must be set if None) backend: the backend used for generation (optional: session must be set if None) precondition_requirements: A list of requirements that the genslot inputs are validated against; raises an err if not met. requirements: A list of requirements that the genslot output can be validated against. 
@@ -413,6 +363,9 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: Returns: R: an object with the original return type of the function + + Raises: + TypeError: if any of the original function's parameters were passed as positional args """ extracted = self.extract_args_and_kwargs(*args, **kwargs) @@ -426,8 +379,8 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: response_model = create_response_format(self._function._func) response, context = None, None - if extracted.session is not None: - response = extracted.session.act( + if extracted.m is not None: + response = extracted.m.act( slot_copy, requirements=extracted.requirements, strategy=extracted.strategy, @@ -490,8 +443,8 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R """Call the async generative slot. Args: - m: MelleaSession: A mellea session (optional; must set context and backend if None) - context: the Context object (optional; session must be set if None) + m: MelleaSession: A mellea session (optional: must set context and backend if None) + context: the Context object (optional: session must be set if None) backend: the backend used for generation (optional: session must be set if None) precondition_requirements: A list of requirements that the genslot inputs are validated against; raises an err if not met. requirements: A list of requirements that the genslot output can be validated against. @@ -504,7 +457,7 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R Coroutine[Any, Any, R]: a coroutine that returns an object with the original return type of the function Raises: - # TODO: JAL. 
Change here and for all other defs; should be precondition requirements err and pydantic model validation fails + TypeError: if any of the original function's parameters were passed as positional args """ extracted = self.extract_args_and_kwargs(*args, **kwargs) @@ -530,8 +483,8 @@ async def __async_call__() -> tuple[R, Context] | R: response, context = None, None # Use the async act func so that control flow doesn't get stuck here in async event loops. - if extracted.session is not None: - response = await extracted.session.aact( + if extracted.m is not None: + response = await extracted.m.aact( slot_copy, requirements=extracted.requirements, strategy=extracted.strategy, @@ -582,7 +535,11 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: that function's behavior. The output is guaranteed to match the return type annotation using structured outputs and automatic validation. - Note: Works with async functions as well. + Notes: + - Works with async functions as well. + - Must pass all parameters for the original function as keyword args. + - Most python type-hinters will not show the default values but will correctly infer them; + this means that you can set default values in the decorated function and the only necessary values will be a session or a (context, backend). Tip: Write the function and docstring in the most Pythonic way possible, not like a prompt. This ensures the function is well-documented, easily understood, @@ -597,7 +554,9 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: original function's signature and docstring. Raises: - ValidationError: if the generated output cannot be parsed into the expected return type. Typically happens when the token limit for the generated output results in invalid json. 
+ ValueError: (raised by @generative) if the decorated function has a parameter name used by generative slots + ValidationError: (raised when calling the generative slot) if the generated output cannot be parsed into the expected return type. Typically happens when the token limit for the generated output results in invalid json. + TypeError: (raised when calling the generative slot) if any of the original function's parameters were passed as positional args Examples: >>> from mellea import generative, start_session @@ -607,7 +566,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... '''Generate a concise summary of the input text.''' ... ... >>> - >>> summary = summarize_text(session, "Long text...", max_words=30) + >>> summary = summarize_text(session, text="Long text...", max_words=30) >>> from typing import List >>> from dataclasses import dataclass @@ -631,7 +590,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... ''' ... ... >>> - >>> tasks = await create_project_tasks(session, "Build a web app", 5) + >>> tasks = await create_project_tasks(session, project_desc="Build a web app", count=5) >>> @generative ... def analyze_code_quality(code: str) -> Dict[str, Any]: @@ -651,7 +610,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: >>> >>> analysis = analyze_code_quality( ... session, - ... "def factorial(n): return n * factorial(n-1)", + ... code="def factorial(n): return n * factorial(n-1)", ... model_options={"temperature": 0.3} ... ) @@ -673,7 +632,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... ''' ... ... >>> - >>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?") + >>> reasoning = generate_chain_of_thought(session, problem="How to optimize a slow database query?") """ # Grab and remove the func if it exists in kwargs. Otherwise, it's the only arg. 
if inspect.iscoroutinefunction(func): diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index b015f7f4..77a35da9 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -2,9 +2,24 @@ import pytest from typing import Literal from mellea import generative, start_session -from mellea.stdlib.base import Context +from mellea.backends.model_ids import META_LLAMA_3_2_1B +from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.base import ChatContext, Context from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, SyncGenerativeSlot - +from mellea.stdlib.sampling.base import RejectionSamplingStrategy +from mellea.stdlib.session import MelleaSession + +@pytest.fixture(scope="module") +def backend(gh_run: int): + """Shared backend.""" + if gh_run == 1: + return OllamaModelBackend( + model_id=META_LLAMA_3_2_1B.ollama_name, # type: ignore + ) + else: + return OllamaModelBackend( + model_id="granite3.3:8b", + ) @generative def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... 
@@ -67,18 +82,85 @@ async def test_async_gen_slot(session): assert isinstance(c3, Context) assert len(results) == 2 -def test_duplicate_args(session): - with pytest.raises(ValueError, match="passed in multiple values"): - _ = write_me_an_email(session.ctx, backend=session.backend, context=session.ctx) # type: ignore - -def test_extra_args(session): - with pytest.raises(TypeError, match="got an unexpected keyword argument"): - _ = write_me_an_email(m=session, random_param="random_param") # type: ignore - -def test_without_required_args(): - with pytest.raises(ValueError, match="need to pass in a session or a"): - _ = write_me_an_email() # type: ignore - +@pytest.mark.parametrize( + "arg_choices,kwarg_choices,errs", + [ + pytest.param(["m"], ["func1", "func2", "func3"], False, id="session"), + pytest.param(["context"], ["backend"], False, id="context and backend"), + pytest.param(["backend"], ["func1", "func2", "func3"], True, id="backend without context"), + pytest.param(["m"], ["m"], True, id="duplicate arg and kwarg"), + pytest.param(["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], [], True, id="original func args as positional args"), + pytest.param([], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"), + pytest.param([], ["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], False, id="all kwargs"), + pytest.param([], ["func1", "m", "func2", "requirements", "func3"], False, id="interspersed kwargs"), + pytest.param([], [], True, id="missing required args") + ] +) +def test_arg_extraction(backend, arg_choices, kwarg_choices, errs): + """Tests the internal extract_args_and_kwargs function. + + This function has to test a large number of input combinations; as a result, + it uses a parameterization scheme. It takes a list and a dict. Each contains + strings corresponding to the possible args/kwargs below. Order matters in the list. 
+ See the param id for an idea of what the test does. + + Python should catch most of these issues itself. We have to manually raise an exception for + the arguments of the original function being positional. + """ + + # List of all needed values. + backend = backend + ctx = ChatContext() + session = MelleaSession(backend, ctx) + precondition_requirements = ["precondition"] + requirements = None + strategy = RejectionSamplingStrategy() + model_options = {"test": 1} + func1 = 1 + func2 = True + func3 = "func3" + + # Lookup table by name. + vals = { + "m": session, + "backend": backend, + "context": ctx, + "precondition_requirements": precondition_requirements, + "requirements": requirements, + "strategy": strategy, + "model_options": model_options, + "func1": func1, + "func2": func2, + "func3": func3, + } + + args = [] + for arg in arg_choices: + args.append(vals[arg]) + + kwargs = {} + for kwarg in kwarg_choices: + kwargs[kwarg] = vals[kwarg] + + # Run the extraction and check for the (un-)expected exception. + found_err = False + err = None + try: + GenerativeSlot.extract_args_and_kwargs(*args, **kwargs) + except Exception as e: + found_err = True + err = e + + if errs: + assert found_err, "expected an exception and got none" + else: + assert not found_err, f"got unexpected err: {err}" + +def test_disallowed_parameter_names(): + with pytest.raises(ValueError): + @generative + def test(backend): + ... 
if __name__ == "__main__": pytest.main([__file__]) From ebbfbb6c129f362f5947882b94912275e0e5f5c5 Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Wed, 5 Nov 2025 17:06:09 -0500 Subject: [PATCH 7/9] fix: stop passing along format when validating in sampling strategies --- mellea/stdlib/sampling/base.py | 2 +- mellea/stdlib/sampling/best_of_n.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 6520a7da..e5b5bc5e 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -166,7 +166,7 @@ async def sample( context=result_ctx, backend=backend, output=result, - format=format, + format=None, model_options=model_options, # tool_calls=tool_calls # Don't support using tool calls in validation strategies. ) diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index 59c402b2..0f8aaffb 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -132,7 +132,7 @@ async def sample( context=result_ctx, backend=backend, output=result, - format=format, + format=None, model_options=model_options, input=next_action._description, # type: ignore # tool_calls=tool_calls # Don't support using tool calls in validation strategies. 
From 238fdb46a24f5dbfbe1bd2791e9691772753860e Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Fri, 7 Nov 2025 17:27:34 -0500 Subject: [PATCH 8/9] feat: add requirements / preconditions to gen slots --- .../generative_slots/generative_slots.py | 3 +- .../generative_slots_with_requirements.py | 69 ++++ mellea/stdlib/genslot.py | 301 +++++++++++++----- .../default/ArgPreconditionRequirement.jinja2 | 17 + .../prompts/default/GenerativeSlot.jinja2 | 18 +- test/stdlib_basics/test_genslot.py | 28 +- 6 files changed, 356 insertions(+), 80 deletions(-) create mode 100644 docs/examples/generative_slots/generative_slots_with_requirements.py create mode 100644 mellea/templates/prompts/default/ArgPreconditionRequirement.jinja2 diff --git a/docs/examples/generative_slots/generative_slots.py b/docs/examples/generative_slots/generative_slots.py index 1b77de44..2e1f5e40 100644 --- a/docs/examples/generative_slots/generative_slots.py +++ b/docs/examples/generative_slots/generative_slots.py @@ -20,6 +20,7 @@ def generate_summary(text: str) -> str: print("Output sentiment is : ", sentiment_component) summary = generate_summary( + m=m, text=""" The eagle rays are a group of cartilaginous fishes in the family Myliobatidae, consisting mostly of large species living in the open ocean rather than on the sea bottom. @@ -28,6 +29,6 @@ def generate_summary(text: str) -> str: surface. Compared with other rays, they have long tails, and well-defined, rhomboidal bodies. They are ovoviviparous, giving birth to up to six young at a time. They range from 0.48 to 5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan. 
- """ + """, ) print("Generated summary is :", summary) diff --git a/docs/examples/generative_slots/generative_slots_with_requirements.py b/docs/examples/generative_slots/generative_slots_with_requirements.py new file mode 100644 index 00000000..28995e0d --- /dev/null +++ b/docs/examples/generative_slots/generative_slots_with_requirements.py @@ -0,0 +1,69 @@ +from typing import Literal + +from mellea import generative, start_session +from mellea.stdlib.genslot import PreconditionException +from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + +@generative +def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: + """Classify the sentiment of the text.""" + ... + + +if __name__ == "__main__": + m = start_session() + + # Add preconditions and requirements. + sentiment_component = classify_sentiment( + m, + text="I love this!", + + # Preconditions are only checked with basic validation. Don't use the strategy. + precondition_requirements=["the text arg should be less than 100 words"], + + # Reqs to use with the strategy. You could also just remove "unknown" from the structured output for this. + requirements=["avoid classifying the sentiment as unknown"], + strategy=RejectionSamplingStrategy() # Must specify a strategy for gen slots + ) + + print(f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]["content"]}\n```") # type: ignore + # Prompt to the model looked like: + # ``` + # Your task is to imitate the output of the following function for the given arguments. + # Reply Nothing else but the output of the function. + + # Function: + # def classify_sentiment(text: str) -> Literal['positive', 'negative', 'unknown']: + # """Classify the sentiment of the text. + + # Postconditions: + # - avoid classifying the sentiment as unknown + # """ + + # Arguments: + # - text: "I love this!" 
(type: ) + # ``` + + print("\nOutput sentiment is:", sentiment_component) + + + # We can also force a precondition failure. + try: + sentiment_component = classify_sentiment( + m, + text="I hate this!", + + # Requirement always fails to validate given the lambda. + precondition_requirements=[ + Requirement("the text arg should be only one word", validation_fn=simple_validate(lambda x: (False, "Forced to fail!"))) + ], + ) + except PreconditionException as e: + print(f"exception: {str(e)}") + + # Look at why the precondition validation failed. + print("Failure reasons:") + for val_result in e.validation: + print("-", val_result.reason) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index d3dc5dec..9f1a089b 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -4,26 +4,23 @@ import functools import inspect from collections.abc import Awaitable, Callable, Coroutine -from copy import deepcopy +from copy import copy, deepcopy from dataclasses import dataclass, fields -from typing import ( - Any, - Generic, - ParamSpec, - TypedDict, - TypeVar, - Union, - get_type_hints, - overload, -) +from typing import Any, Generic, ParamSpec, TypedDict, TypeVar, get_type_hints, overload from pydantic import BaseModel, Field, create_model import mellea.stdlib.funcs as mfuncs from mellea.backends import Backend -from mellea.stdlib.base import Component, Context, TemplateRepresentation -from mellea.stdlib.requirement import Requirement, reqify -from mellea.stdlib.sampling.base import RejectionSamplingStrategy +from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import ( + CBlock, + Component, + Context, + ModelOutputThunk, + TemplateRepresentation, +) +from mellea.stdlib.requirement import Requirement, ValidationResult, reqify from mellea.stdlib.sampling.types import SamplingStrategy from mellea.stdlib.session import MelleaSession @@ -76,6 +73,81 @@ class ArgumentDict(TypedDict): value: str | None +class Argument: + """An 
Argument Component.""" + + def __init__( + self, + annotation: str | None = None, + name: str | None = None, + value: str | None = None, + ): + """An Argument Component.""" + self._argument_dict: ArgumentDict = { + "name": name, + "annotation": annotation, + "value": value, + } + + +class Arguments(CBlock): + def __init__(self, arguments: list[Argument]): + """Create a textual representation of a list of arguments.""" + # Make meta the original list of arguments and create a list of textual representations. + meta: dict[str, Any] = {} + text_args = [] + for arg in arguments: + assert arg._argument_dict["name"] is not None + meta[arg._argument_dict["name"]] = arg + text_args.append( + f"- {arg._argument_dict['name']}: {arg._argument_dict['value']} (type: {arg._argument_dict['annotation']})" + ) + + super().__init__("\n".join(text_args), meta) + + +class ArgPreconditionRequirement(Requirement): + """Specific requirement with template for validating precondition requirements against a set of args.""" + + def __init__(self, req: Requirement): + """Can only be instantiated from existing requirements. All function calls are delegated to the underlying requirement.""" + self.req = req + + def __getattr__(self, name): + return getattr(self.req, name) + + def __copy__(self): + return ArgPreconditionRequirement(req=self.req) + + def __deepcopy__(self, memo): + return ArgPreconditionRequirement(deepcopy(self.req, memo)) + + +class PreconditionException(Exception): + """Exception raised when validation fails for a generative slot's arguments.""" + + def __init__( + self, message: str, validation_results: list[ValidationResult] + ) -> None: + """Exception raised when validation fails for a generative slot's arguments. 
+ + Args: + message: the error message + validation_results: the list of validation results from the failed preconditions + """ + super().__init__(message) + self.validation = validation_results + + +class Function: + """A Function Component.""" + + def __init__(self, func: Callable): + """A Function Component.""" + self._func: Callable = func + self._function_dict: FunctionDict = describe_function(func) + + def describe_function(func: Callable) -> FunctionDict: """Generates a FunctionDict given a function. @@ -92,22 +164,30 @@ def describe_function(func: Callable) -> FunctionDict: } -def get_annotation(func: Callable, key: str, val: Any) -> str: - """Returns a annotated list of arguments for a given function and list of arguments. +def get_argument(func: Callable, key: str, val: Any) -> Argument: + """Returns an argument given a parameter. + + Note: Performs additional formatting for string objects, putting them in quotes. Args: func : Callable Function - key : Arg keys - val : Arg Values + key : Arg key + val : Arg value Returns: - str: An annotated string for a given func. + Argument: an argument object representing the given parameter. 
""" sig = inspect.signature(func) param = sig.parameters.get(key) if param and param.annotation is not inspect.Parameter.empty: - return str(param.annotation) - return str(type(val)) + param_type = param.annotation + else: + param_type = type(val) + + if param_type is str: + val = f'"{val!s}"' + + return Argument(str(param_type), key, val) def bind_function_arguments( @@ -129,32 +209,6 @@ def bind_function_arguments( return dict(bound_arguments.arguments) -class Argument: - """An Argument Component.""" - - def __init__( - self, - annotation: str | None = None, - name: str | None = None, - value: str | None = None, - ): - """An Argument Component.""" - self._argument_dict: ArgumentDict = { - "name": name, - "annotation": annotation, - "value": value, - } - - -class Function: - """A Function Component.""" - - def __init__(self, func: Callable): - """A Function Component.""" - self._func: Callable = func - self._function_dict: FunctionDict = describe_function(func) - - @dataclass class ExtractedArgs: """Used to extract the mellea args and original function args. See @generative decorator for additional notes on these fields. @@ -214,7 +268,7 @@ def __init__(self, func: Callable[P, R]): ) self._function = Function(func) - self._arguments: list[Argument] = [] + self._arguments: Arguments | None = None functools.update_wrapper(self, func) # Set when calling the decorated func. @@ -314,7 +368,14 @@ def format_for_llm(self) -> TemplateRepresentation: obj=self, args={ "function": self._function._function_dict, - "arguments": [a._argument_dict for a in self._arguments], + "arguments": self._arguments, + "requirements": [ + r.description + for r in self.requirements + if r.description is not None + and r.description != "" + and not r.check_only + ], # Same conditions on requirements as in instruction. 
}, tools=None, template_order=["*", "GenerativeSlot"], @@ -354,7 +415,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: m: MelleaSession: A mellea session (optional: must set context and backend if None) context: the Context object (optional: session must be set if None) backend: the backend used for generation (optional: session must be set if None) - precondition_requirements: A list of requirements that the genslot inputs are validated against; raises an err if not met. + precondition_requirements: A list of requirements that the genslot inputs are validated against; does not use a sampling strategy. requirements: A list of requirements that the genslot output can be validated against. strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. model_options: Model options to pass to the backend. @@ -362,27 +423,73 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: **kwargs: Additional Kwargs to be passed to the func. 
Returns: - R: an object with the original return type of the function + R: an object with the original return type of the function (returned directly; only the async slot returns a coroutine) Raises: TypeError: if any of the original function's parameters were passed as positional args + PreconditionException: if the precondition validation fails, catch the err to get the validation results """ extracted = self.extract_args_and_kwargs(*args, **kwargs) slot_copy = deepcopy(self) + if extracted.requirements is not None: + slot_copy.requirements = [reqify(r) for r in extracted.requirements] + + if extracted.precondition_requirements is not None: + slot_copy.precondition_requirements = [ + ArgPreconditionRequirement(reqify(r)) + for r in extracted.precondition_requirements + ] + arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) if arguments: + slot_args: list[Argument] = [] for key, val in arguments.items(): - annotation = get_annotation(slot_copy._function._func, key, val) - slot_copy._arguments.append(Argument(annotation, key, val)) + slot_args.append(get_argument(slot_copy._function._func, key, val)) + slot_copy._arguments = Arguments(slot_args) response_model = create_response_format(self._function._func) + # Do precondition validation first. + if slot_copy._arguments is not None: + if extracted.m is not None: + val_results = extracted.m.validate( + reqs=slot_copy.precondition_requirements, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. + assert extracted.context is not None + assert extracted.backend is not None + val_results = mfuncs.validate( + reqs=slot_copy.precondition_requirements, + context=extracted.context, + backend=extracted.backend, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + + # No retries if precondition validation fails. 
+ if not all(bool(val_result) for val_result in val_results): + FancyLogger.get_logger().error( + "generative slot arguments did not satisfy precondition requirements" + ) + raise PreconditionException( + "generative slot arguments did not satisfy precondition requirements", + validation_results=val_results, + ) + + elif len(slot_copy.precondition_requirements) > 0: + FancyLogger.get_logger().warning( + "calling a generative slot with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" + ) + response, context = None, None if extracted.m is not None: response = extracted.m.act( slot_copy, - requirements=extracted.requirements, + requirements=slot_copy.requirements, strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, @@ -395,7 +502,7 @@ def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: slot_copy, extracted.context, extracted.backend, - requirements=extracted.requirements, + requirements=slot_copy.requirements, strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, @@ -446,7 +553,7 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R m: MelleaSession: A mellea session (optional: must set context and backend if None) context: the Context object (optional: session must be set if None) backend: the backend used for generation (optional: session must be set if None) - precondition_requirements: A list of requirements that the genslot inputs are validated against; raises an err if not met. + precondition_requirements: A list of requirements that the genslot inputs are validated against; does not use a sampling strategy. requirements: A list of requirements that the genslot output can be validated against. strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. 
model_options: Model options to pass to the backend. @@ -458,35 +565,74 @@ def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R Raises: TypeError: if any of the original function's parameters were passed as positional args + PreconditionException: if the precondition validation fails, catch the err to get the validation results """ extracted = self.extract_args_and_kwargs(*args, **kwargs) slot_copy = deepcopy(self) - slot_copy.requirements = extracted.requirements - slot_copy.precondition_requirements = extracted.precondition_requirements - # TODO: JAL; need to figure out where / how reqs work; if we want to keep as a part of the object, - # apply them here after the copy has happened... - # need to change the template; add to docstring using postconditions: - # Postconditions: - # - The input 'data' list will be sorted in ascending order. + if extracted.requirements is not None: + slot_copy.requirements = [reqify(r) for r in extracted.requirements] + + if extracted.precondition_requirements is not None: + slot_copy.precondition_requirements = [ + ArgPreconditionRequirement(reqify(r)) + for r in extracted.precondition_requirements + ] + arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) if arguments: + slot_args: list[Argument] = [] for key, val in arguments.items(): - annotation = get_annotation(slot_copy._function._func, key, val) - slot_copy._arguments.append(Argument(annotation, key, val)) + slot_args.append(get_argument(slot_copy._function._func, key, val)) + slot_copy._arguments = Arguments(slot_args) response_model = create_response_format(self._function._func) # AsyncGenerativeSlots are used with async functions. In order to support that behavior, # they must return a coroutine object. 
async def __async_call__() -> tuple[R, Context] | R: + """Use async calls so that control flow doesn't get stuck here in async event loops.""" response, context = None, None - # Use the async act func so that control flow doesn't get stuck here in async event loops. + # Do precondition validation first. + if slot_copy._arguments is not None: + if extracted.m is not None: + val_results = await extracted.m.avalidate( + reqs=slot_copy.precondition_requirements, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. + assert extracted.context is not None + assert extracted.backend is not None + val_results = await mfuncs.avalidate( + reqs=slot_copy.precondition_requirements, + context=extracted.context, + backend=extracted.backend, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + + # No retries if precondition validation fails. 
+ if not all(bool(val_result) for val_result in val_results): + FancyLogger.get_logger().error( + "generative slot arguments did not satisfy precondition requirements" + ) + raise PreconditionException( + "generative slot arguments did not satisfy precondition requirements", + validation_results=val_results, + ) + + elif len(slot_copy.precondition_requirements) > 0: + FancyLogger.get_logger().warning( + "calling a generative slot with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" + ) + if extracted.m is not None: response = await extracted.m.aact( slot_copy, - requirements=extracted.requirements, + requirements=slot_copy.requirements, strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, @@ -499,7 +645,7 @@ async def __async_call__() -> tuple[R, Context] | R: slot_copy, extracted.context, extracted.backend, - requirements=extracted.requirements, + requirements=slot_copy.requirements, strategy=extracted.strategy, format=response_model, model_options=extracted.model_options, @@ -525,7 +671,6 @@ def generative(func: Callable[P, Awaitable[R]]) -> AsyncGenerativeSlot[P, R]: .. def generative(func: Callable[P, R]) -> SyncGenerativeSlot[P, R]: ... -# TODO: JAL Investigate changing genslots to functions and see if it fixes the defaults being populated. def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: """Convert a function into an AI-powered function. @@ -546,6 +691,18 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: and familiar to any Python developer. The more natural and conventional your function definition, the better the AI will understand and imitate it. 
+ The new function has the following additional args: + *m*: MelleaSession: A mellea session (optional: must set context and backend if None) + *context*: Context: the Context object (optional: session must be set if None) + *backend*: Backend: the backend used for generation (optional: session must be set if None) + *precondition_requirements*: list[Requirement | str] | None: A list of requirements that the genslot inputs are validated against; raises an err if not met. + *requirements*: list[Requirement | str] | None: A list of requirements that the genslot output can be validated against. + *strategy*: SamplingStrategy | None: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. + *model_options*: dict | None: Model options to pass to the backend. + + The requirements and validation for the generative function operate over a textual representation + of the arguments / outputs (not their python objects). + Args: func: Function with docstring and type hints. Implementation can be empty (...). @@ -557,6 +714,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ValueError: (raised by @generative) if the decorated function has a parameter name used by generative slots ValidationError: (raised when calling the generative slot) if the generated output cannot be parsed into the expected return type. Typically happens when the token limit for the generated output results in invalid json. 
TypeError: (raised when calling the generative slot) if any of the original function's parameters were passed as positional args + PreconditionException: (raised when calling the generative slot) if the precondition validation of the args fails; catch the exception to get the validation results Examples: >>> from mellea import generative, start_session @@ -634,12 +792,11 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: >>> >>> reasoning = generate_chain_of_thought(session, problem="How to optimize a slow database query?") """ - # Grab and remove the func if it exists in kwargs. Otherwise, it's the only arg. if inspect.iscoroutinefunction(func): return AsyncGenerativeSlot(func) else: return SyncGenerativeSlot(func) -# Export the decorator as the interface -__all__ = ["generative"] +# Export the decorator as the interface. Export the specific exception for debugging. +__all__ = ["PreconditionException", "generative"] diff --git a/mellea/templates/prompts/default/ArgPreconditionRequirement.jinja2 b/mellea/templates/prompts/default/ArgPreconditionRequirement.jinja2 new file mode 100644 index 00000000..7c96c046 --- /dev/null +++ b/mellea/templates/prompts/default/ArgPreconditionRequirement.jinja2 @@ -0,0 +1,17 @@ +Please check if the following arguments satisfy the precondition. +Reply with 'yes' if the precondition is satisfied and 'no' otherwise. +Do not include any other text in your response. 
+ +{%- block output -%} +{% if output %} + +Arguments: +{{ output -}} +{%- endif -%} +{% endblock output%} + +{%- block description %} +{% if description %} +Precondition: {{ description -}} +{%- endif -%} +{% endblock description %} diff --git a/mellea/templates/prompts/default/GenerativeSlot.jinja2 b/mellea/templates/prompts/default/GenerativeSlot.jinja2 index 309bfc9b..8305dedc 100644 --- a/mellea/templates/prompts/default/GenerativeSlot.jinja2 +++ b/mellea/templates/prompts/default/GenerativeSlot.jinja2 @@ -2,13 +2,19 @@ Your task is to imitate the output of the following function for the given argum Reply Nothing else but the output of the function. Function: -def {{ function.name }}({{ function.signature }}) - """{{ function.docstring | default("No documentation provided.") }}""" +def {{ function.name }}{{ function.signature }}: + """{{ function.docstring | default("No documentation provided.") -}} + {% if requirements|length > 0 %} -{%- if arguments|length > 0 %} + Postconditions: + {%- for req in requirements %} + - {{ req -}} + {% endfor %} + {% endif -%} + """ + +{%- if arguments is not none %} Arguments: -{%- for arg in arguments %} -- {{ arg.name }}: {{ arg.value }} (type: {{ arg.annotation -}}) -{%- endfor -%} +{{ arguments }} {%- endif -%} \ No newline at end of file diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index 77a35da9..78f43f14 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -5,7 +5,8 @@ from mellea.backends.model_ids import META_LLAMA_3_2_1B from mellea.backends.ollama import OllamaModelBackend from mellea.stdlib.base import ChatContext, Context -from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, SyncGenerativeSlot +from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, PreconditionException, SyncGenerativeSlot +from mellea.stdlib.requirement import Requirement, simple_validate from mellea.stdlib.sampling.base import 
RejectionSamplingStrategy from mellea.stdlib.session import MelleaSession @@ -162,5 +163,30 @@ def test_disallowed_parameter_names(): def test(backend): ... +def test_precondition_failure(session): + with pytest.raises(PreconditionException): + classify_sentiment( + m=session, + text="hello", + precondition_requirements=[ + Requirement("forced failure", validation_fn=simple_validate(lambda x: (False, ""))) + ] + ) + +def test_requirement(session): + classify_sentiment( + m=session, + text="hello", + requirements=["req1", "req2", Requirement("req3")] + ) + +def test_with_no_args(session): + @generative + def generate_text() -> str: + """Generate text!""" + ... + + generate_text(m=session) + if __name__ == "__main__": pytest.main([__file__]) From 4cc52adbbc772b9d87ec9a06a0962c08d23f948a Mon Sep 17 00:00:00 2001 From: Jake LoRocco Date: Fri, 7 Nov 2025 22:06:31 -0500 Subject: [PATCH 9/9] fix: formatting in new gen slot example --- .../generative_slots_with_requirements.py | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) diff --git a/docs/examples/generative_slots/generative_slots_with_requirements.py b/docs/examples/generative_slots/generative_slots_with_requirements.py index 28995e0d..7304bc3c 100644 --- a/docs/examples/generative_slots/generative_slots_with_requirements.py +++ b/docs/examples/generative_slots/generative_slots_with_requirements.py @@ -19,16 +19,16 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: sentiment_component = classify_sentiment( m, text="I love this!", - # Preconditions are only checked with basic validation. Don't use the strategy. - precondition_requirements=["the text arg should be less than 100 words"], - + precondition_requirements=["the text arg should be less than 100 words"], # Reqs to use with the strategy. You could also just remove "unknown" from the structured output for this. 
requirements=["avoid classifying the sentiment as unknown"], - strategy=RejectionSamplingStrategy() # Must specify a strategy for gen slots + strategy=RejectionSamplingStrategy(), # Must specify a strategy for gen slots ) - print(f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]["content"]}\n```") # type: ignore + print( + f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```" + ) # type: ignore # Prompt to the model looked like: # ``` # Your task is to imitate the output of the following function for the given arguments. @@ -48,17 +48,18 @@ def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: print("\nOutput sentiment is:", sentiment_component) - # We can also force a precondition failure. try: sentiment_component = classify_sentiment( m, text="I hate this!", - # Requirement always fails to validate given the lambda. precondition_requirements=[ - Requirement("the text arg should be only one word", validation_fn=simple_validate(lambda x: (False, "Forced to fail!"))) - ], + Requirement( + "the text arg should be only one word", + validation_fn=simple_validate(lambda x: (False, "Forced to fail!")), + ) + ], ) except PreconditionException as e: print(f"exception: {str(e)}")