diff --git a/docs/examples/generative_slots/generative_slots.py b/docs/examples/generative_slots/generative_slots.py index 1b77de44..2e1f5e40 100644 --- a/docs/examples/generative_slots/generative_slots.py +++ b/docs/examples/generative_slots/generative_slots.py @@ -20,6 +20,7 @@ def generate_summary(text: str) -> str: print("Output sentiment is : ", sentiment_component) summary = generate_summary( + m=m, text=""" The eagle rays are a group of cartilaginous fishes in the family Myliobatidae, consisting mostly of large species living in the open ocean rather than on the sea bottom. @@ -28,6 +29,6 @@ def generate_summary(text: str) -> str: surface. Compared with other rays, they have long tails, and well-defined, rhomboidal bodies. They are ovoviviparous, giving birth to up to six young at a time. They range from 0.48 to 5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan. - """ + """, ) print("Generated summary is :", summary) diff --git a/docs/examples/generative_slots/generative_slots_with_requirements.py b/docs/examples/generative_slots/generative_slots_with_requirements.py new file mode 100644 index 00000000..7304bc3c --- /dev/null +++ b/docs/examples/generative_slots/generative_slots_with_requirements.py @@ -0,0 +1,70 @@ +from typing import Literal + +from mellea import generative, start_session +from mellea.stdlib.genslot import PreconditionException +from mellea.stdlib.requirement import Requirement, simple_validate +from mellea.stdlib.sampling.base import RejectionSamplingStrategy + + +@generative +def classify_sentiment(text: str) -> Literal["positive", "negative", "unknown"]: + """Classify the sentiment of the text.""" + ... + + +if __name__ == "__main__": + m = start_session() + + # Add preconditions and requirements. + sentiment_component = classify_sentiment( + m, + text="I love this!", + # Preconditions are only checked with basic validation. Don't use the strategy. 
+ precondition_requirements=["the text arg should be less than 100 words"], + # Reqs to use with the strategy. You could also just remove "unknown" from the structured output for this. + requirements=["avoid classifying the sentiment as unknown"], + strategy=RejectionSamplingStrategy(), # Must specify a strategy for gen slots + ) + + print( + f"Prompt to the model looked like:\n```\n{m.last_prompt()[0]['content']}\n```" + ) # type: ignore + # Prompt to the model looked like: + # ``` + # Your task is to imitate the output of the following function for the given arguments. + # Reply Nothing else but the output of the function. + + # Function: + # def classify_sentiment(text: str) -> Literal['positive', 'negative', 'unknown']: + # """Classify the sentiment of the text. + + # Postconditions: + # - avoid classifying the sentiment as unknown + # """ + + # Arguments: + # - text: "I love this!" (type: <class 'str'>) + # ``` + + print("\nOutput sentiment is:", sentiment_component) + + # We can also force a precondition failure. + try: + sentiment_component = classify_sentiment( + m, + text="I hate this!", + # Requirement always fails to validate given the lambda. + precondition_requirements=[ + Requirement( + "the text arg should be only one word", + validation_fn=simple_validate(lambda x: (False, "Forced to fail!")), + ) + ], + ) + except PreconditionException as e: + print(f"exception: {str(e)}") + + # Look at why the precondition validation failed. 
+ print("Failure reasons:") + for val_result in e.validation: + print("-", val_result.reason) diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index da87343a..9f1a089b 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -1,16 +1,28 @@ """A method to generate outputs based on python functions and a Generative Slot function.""" -import asyncio +import abc import functools import inspect -from collections.abc import Callable, Coroutine -from copy import deepcopy -from typing import Any, Generic, ParamSpec, TypedDict, TypeVar, get_type_hints +from collections.abc import Awaitable, Callable, Coroutine +from copy import copy, deepcopy +from dataclasses import dataclass, fields +from typing import Any, Generic, ParamSpec, TypedDict, TypeVar, get_type_hints, overload from pydantic import BaseModel, Field, create_model -from mellea.stdlib.base import Component, TemplateRepresentation -from mellea.stdlib.session import MelleaSession, get_session +import mellea.stdlib.funcs as mfuncs +from mellea.backends import Backend +from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import ( + CBlock, + Component, + Context, + ModelOutputThunk, + TemplateRepresentation, +) +from mellea.stdlib.requirement import Requirement, ValidationResult, reqify +from mellea.stdlib.sampling.types import SamplingStrategy +from mellea.stdlib.session import MelleaSession P = ParamSpec("P") R = TypeVar("R") @@ -61,6 +73,81 @@ class ArgumentDict(TypedDict): value: str | None +class Argument: + """An Argument Component.""" + + def __init__( + self, + annotation: str | None = None, + name: str | None = None, + value: str | None = None, + ): + """An Argument Component.""" + self._argument_dict: ArgumentDict = { + "name": name, + "annotation": annotation, + "value": value, + } + + +class Arguments(CBlock): + def __init__(self, arguments: list[Argument]): + """Create a textual representation of a list of arguments.""" + # Make meta the original list 
of arguments and create a list of textual representations. + meta: dict[str, Any] = {} + text_args = [] + for arg in arguments: + assert arg._argument_dict["name"] is not None + meta[arg._argument_dict["name"]] = arg + text_args.append( + f"- {arg._argument_dict['name']}: {arg._argument_dict['value']} (type: {arg._argument_dict['annotation']})" + ) + + super().__init__("\n".join(text_args), meta) + + +class ArgPreconditionRequirement(Requirement): + """Specific requirement with template for validating precondition requirements against a set of args.""" + + def __init__(self, req: Requirement): + """Can only be instantiated from existing requirements. All function calls are delegated to the underlying requirement.""" + self.req = req + + def __getattr__(self, name): + return getattr(self.req, name) + + def __copy__(self): + return ArgPreconditionRequirement(req=self.req) + + def __deepcopy__(self, memo): + return ArgPreconditionRequirement(deepcopy(self.req, memo)) + + +class PreconditionException(Exception): + """Exception raised when validation fails for a generative slot's arguments.""" + + def __init__( + self, message: str, validation_results: list[ValidationResult] + ) -> None: + """Exception raised when validation fails for a generative slot's arguments. + + Args: + message: the error message + validation_results: the list of validation results from the failed preconditions + """ + super().__init__(message) + self.validation = validation_results + + +class Function: + """A Function Component.""" + + def __init__(self, func: Callable): + """A Function Component.""" + self._func: Callable = func + self._function_dict: FunctionDict = describe_function(func) + + def describe_function(func: Callable) -> FunctionDict: """Generates a FunctionDict given a function. 
@@ -77,22 +164,30 @@ def describe_function(func: Callable) -> FunctionDict: } -def get_annotation(func: Callable, key: str, val: Any) -> str: - """Returns a annotated list of arguments for a given function and list of arguments. +def get_argument(func: Callable, key: str, val: Any) -> Argument: + """Returns an argument given a parameter. + + Note: Performs additional formatting for string objects, putting them in quotes. Args: func : Callable Function - key : Arg keys - val : Arg Values + key : Arg key + val : Arg value Returns: - str: An annotated string for a given func. + Argument: an argument object representing the given parameter. """ sig = inspect.signature(func) param = sig.parameters.get(key) if param and param.annotation is not inspect.Parameter.empty: - return str(param.annotation) - return str(type(val)) + param_type = param.annotation + else: + param_type = type(val) + + if param_type is str: + val = f'"{val!s}"' + + return Argument(str(param_type), key, val) def bind_function_arguments( @@ -114,30 +209,39 @@ def bind_function_arguments( return dict(bound_arguments.arguments) -class Argument: - """An Argument Component.""" +@dataclass +class ExtractedArgs: + """Used to extract the mellea args and original function args. See @generative decorator for additional notes on these fields. - def __init__( - self, - annotation: str | None = None, - name: str | None = None, - value: str | None = None, - ): - """An Argument Component.""" - self._argument_dict: ArgumentDict = { - "name": name, - "annotation": annotation, - "value": value, - } + These args must match those allowed by any overload of GenerativeSlot.__call__. + """ + f_args: tuple[Any, ...] 
+ """*args from the original function, used to detect incorrectly passed args to generative slots""" -class Function: - """A Function Component.""" + f_kwargs: dict[str, Any] + """**kwargs from the original function""" - def __init__(self, func: Callable): - """A Function Component.""" - self._func: Callable = func - self._function_dict: FunctionDict = describe_function(func) + m: MelleaSession | None = None + context: Context | None = None + backend: Backend | None = None + model_options: dict | None = None + strategy: SamplingStrategy | None = None + + precondition_requirements: list[Requirement | str] | None = None + """requirements used to check the input""" + + requirements: list[Requirement | str] | None = None + """requirements used to check the output""" + + def __init__(self): + """Used to extract the mellea args and original function args.""" + self.f_args = tuple() + self.f_kwargs = {} + + +_disallowed_param_names = [field.name for field in fields(ExtractedArgs())] +"""A list of parameter names used by Mellea. Cannot use these in functions decorated with @generative.""" class GenerativeSlot(Component, Generic[P, R]): @@ -148,113 +252,425 @@ def __init__(self, func: Callable[P, R]): Args: func: A callable function + + Raises: + ValueError: if the decorated function has a parameter name used by generative slots """ + sig = inspect.signature(func) + problematic_param_names: list[str] = [] + for param in sig.parameters.keys(): + if param in _disallowed_param_names: + problematic_param_names.append(param) + + if len(problematic_param_names): + raise ValueError( + f"cannot create a generative slot with disallowed parameter names: {problematic_param_names}" + ) + self._function = Function(func) - self._arguments: list[Argument] = [] + self._arguments: Arguments | None = None functools.update_wrapper(self, func) + # Set when calling the decorated func. 
+ self.precondition_requirements: list[Requirement] = [] + self.requirements: list[Requirement] = [] + + @abc.abstractmethod + def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: + """Call the generative slot. See subclasses for more information.""" + ... + + @staticmethod + def extract_args_and_kwargs(*args, **kwargs) -> ExtractedArgs: + """Takes a mix of args and kwargs for both the generative slot and the original function and extracts them. Ensures the original function's args are all kwargs. + + Returns: + ExtractedArgs: a dataclass of the required args for mellea and the original function. + Either session or (backend, context) will be non-None. + + Raises: + TypeError: if any of the original function's parameters were passed as positional args + """ + + def _session_extract_args_and_kwargs( + m: MelleaSession, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, + model_options: dict | None = None, + *args, + **kwargs, + ): + """Helper function for extracting args. Used when a session is passed.""" + extracted = ExtractedArgs() + extracted.m = m + extracted.precondition_requirements = precondition_requirements + extracted.requirements = requirements + extracted.strategy = strategy + extracted.model_options = model_options + extracted.f_args = args + extracted.f_kwargs = kwargs + return extracted + + def _context_backend_extract_args_and_kwargs( + context: Context, + backend: Backend, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, + model_options: dict | None = None, + *args, + **kwargs, + ): + """Helper function for extracting args. 
Used when a context and a backend are passed.""" + extracted = ExtractedArgs() + extracted.context = context + extracted.backend = backend + extracted.precondition_requirements = precondition_requirements + extracted.requirements = requirements + extracted.strategy = strategy + extracted.model_options = model_options + extracted.f_args = args + extracted.f_kwargs = kwargs + return extracted + + # Determine which overload was used: + # - if there's args, the first arg must either be a `MelleaSession` or a `Context` + # - otherwise, just check the kwargs for a "m" that is type `MelleaSession` + using_session_overload = False + if len(args) > 0: + possible_session = args[0] + else: + possible_session = kwargs.get("m", None) + if isinstance(possible_session, MelleaSession): + using_session_overload = True + + # Call the appropriate function and let python handle the arg/kwarg extraction. + if using_session_overload: + extracted = _session_extract_args_and_kwargs(*args, **kwargs) + else: + extracted = _context_backend_extract_args_and_kwargs(*args, **kwargs) + + if len(extracted.f_args) > 0: + raise TypeError( + "generative slots do not accept positional args from the decorated function; use keyword args instead" + ) + + return extracted + + def parts(self): + """Not implemented.""" + raise NotImplementedError + + def format_for_llm(self) -> TemplateRepresentation: + """Formats the instruction for Formatter use.""" + return TemplateRepresentation( + obj=self, + args={ + "function": self._function._function_dict, + "arguments": self._arguments, + "requirements": [ + r.description + for r in self.requirements + if r.description is not None + and r.description != "" + and not r.check_only + ], # Same conditions on requirements as in instruction. 
+ }, + tools=None, + template_order=["*", "GenerativeSlot"], + ) + + +class SyncGenerativeSlot(GenerativeSlot, Generic[P, R]): + @overload def __call__( self, - m: MelleaSession | None = None, + context: Context, + backend: Backend, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, - ) -> R: + ) -> tuple[R, Context]: ... + + @overload + def __call__( + self, + m: MelleaSession, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, + model_options: dict | None = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> R: ... + + def __call__(self, *args, **kwargs) -> tuple[R, Context] | R: """Call the generative slot. Args: - m: MelleaSession: A mellea session (optional, uses context if None) + m: MelleaSession: A mellea session (optional: must set context and backend if None) + context: the Context object (optional: session must be set if None) + backend: the backend used for generation (optional: session must be set if None) + precondition_requirements: A list of requirements that the genslot inputs are validated against; does not use a sampling strategy. + requirements: A list of requirements that the genslot output can be validated against. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. model_options: Model options to pass to the backend. *args: Additional args to be passed to the func. **kwargs: Additional Kwargs to be passed to the func. 
 Returns: - R: an object with the original return type of the function + R | tuple[R, Context]: an object with the original return type of the function; when called with a (context, backend) instead of a session, a tuple of that object and the resulting Context + + Raises: + TypeError: if any of the original function's parameters were passed as positional args + PreconditionException: if the precondition validation fails, catch the err to get the validation results """ - if m is None: - m = get_session() + extracted = self.extract_args_and_kwargs(*args, **kwargs) + slot_copy = deepcopy(self) - arguments = bind_function_arguments(self._function._func, *args, **kwargs) + if extracted.requirements is not None: + slot_copy.requirements = [reqify(r) for r in extracted.requirements] + + if extracted.precondition_requirements is not None: + slot_copy.precondition_requirements = [ + ArgPreconditionRequirement(reqify(r)) + for r in extracted.precondition_requirements + ] + + arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) if arguments: + slot_args: list[Argument] = [] for key, val in arguments.items(): - annotation = get_annotation(slot_copy._function._func, key, val) - slot_copy._arguments.append(Argument(annotation, key, val)) + slot_args.append(get_argument(slot_copy._function._func, key, val)) + slot_copy._arguments = Arguments(slot_args) response_model = create_response_format(self._function._func) - response = m.act(slot_copy, format=response_model, model_options=model_options) + # Do precondition validation first. + if slot_copy._arguments is not None: + if extracted.m is not None: + val_results = extracted.m.validate( + reqs=slot_copy.precondition_requirements, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. 
+ assert extracted.context is not None + assert extracted.backend is not None + val_results = mfuncs.validate( + reqs=slot_copy.precondition_requirements, + context=extracted.context, + backend=extracted.backend, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + + # No retries if precondition validation fails. + if not all(bool(val_result) for val_result in val_results): + FancyLogger.get_logger().error( + "generative slot arguments did not satisfy precondition requirements" + ) + raise PreconditionException( + "generative slot arguments did not satisfy precondition requirements", + validation_results=val_results, + ) + + elif len(slot_copy.precondition_requirements) > 0: + FancyLogger.get_logger().warning( + "calling a generative slot with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" + ) + + response, context = None, None + if extracted.m is not None: + response = extracted.m.act( + slot_copy, + requirements=slot_copy.requirements, + strategy=extracted.strategy, + format=response_model, + model_options=extracted.model_options, + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. 
+ assert extracted.context is not None + assert extracted.backend is not None + response, context = mfuncs.act( + slot_copy, + extracted.context, + extracted.backend, + requirements=slot_copy.requirements, + strategy=extracted.strategy, + format=response_model, + model_options=extracted.model_options, + ) function_response: FunctionResponse[R] = response_model.model_validate_json( response.value # type: ignore ) - return function_response.result - - def parts(self): - """Not implemented.""" - raise NotImplementedError - - def format_for_llm(self) -> TemplateRepresentation: - """Formats the instruction for Formatter use.""" - return TemplateRepresentation( - obj=self, - args={ - "function": self._function._function_dict, - "arguments": [a._argument_dict for a in self._arguments], - }, - tools=None, - template_order=["*", "GenerativeSlot"], - ) + if context is None: + return function_response.result + else: + return function_response.result, context class AsyncGenerativeSlot(GenerativeSlot, Generic[P, R]): """A generative slot component that generates asynchronously and returns a coroutine.""" + @overload def __call__( self, - m: MelleaSession | None = None, + context: Context, + backend: Backend, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs, - ) -> Coroutine[Any, Any, R]: + ) -> Coroutine[Any, Any, tuple[R, Context]]: ... + + @overload + def __call__( + self, + m: MelleaSession, + precondition_requirements: list[Requirement | str] | None = None, + requirements: list[Requirement | str] | None = None, + strategy: SamplingStrategy | None = None, + model_options: dict | None = None, + *args: P.args, + **kwargs: P.kwargs, + ) -> Coroutine[Any, Any, R]: ... + + def __call__(self, *args, **kwargs) -> Coroutine[Any, Any, tuple[R, Context] | R]: """Call the async generative slot. 
Args: - m: MelleaSession: A mellea session (optional, uses context if None) + m: MelleaSession: A mellea session (optional: must set context and backend if None) + context: the Context object (optional: session must be set if None) + backend: the backend used for generation (optional: session must be set if None) + precondition_requirements: A list of requirements that the genslot inputs are validated against; does not use a sampling strategy. + requirements: A list of requirements that the genslot output can be validated against. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. model_options: Model options to pass to the backend. *args: Additional args to be passed to the func. - **kwargs: Additional Kwargs to be passed to the func + **kwargs: Additional Kwargs to be passed to the func. Returns: Coroutine[Any, Any, R]: a coroutine that returns an object with the original return type of the function + + Raises: + TypeError: if any of the original function's parameters were passed as positional args + PreconditionException: if the precondition validation fails, catch the err to get the validation results """ - if m is None: - m = get_session() + extracted = self.extract_args_and_kwargs(*args, **kwargs) + slot_copy = deepcopy(self) - arguments = bind_function_arguments(self._function._func, *args, **kwargs) + if extracted.requirements is not None: + slot_copy.requirements = [reqify(r) for r in extracted.requirements] + + if extracted.precondition_requirements is not None: + slot_copy.precondition_requirements = [ + ArgPreconditionRequirement(reqify(r)) + for r in extracted.precondition_requirements + ] + + arguments = bind_function_arguments(self._function._func, **extracted.f_kwargs) if arguments: + slot_args: list[Argument] = [] for key, val in arguments.items(): - annotation = get_annotation(slot_copy._function._func, key, val) - 
slot_copy._arguments.append(Argument(annotation, key, val)) + slot_args.append(get_argument(slot_copy._function._func, key, val)) + slot_copy._arguments = Arguments(slot_args) response_model = create_response_format(self._function._func) # AsyncGenerativeSlots are used with async functions. In order to support that behavior, # they must return a coroutine object. - async def __async_call__() -> R: - # Use the async act func so that control flow doesn't get stuck here in async event loops. - response = await m.aact( - slot_copy, format=response_model, model_options=model_options - ) + async def __async_call__() -> tuple[R, Context] | R: + """Use async calls so that control flow doesn't get stuck here in async event loops.""" + response, context = None, None + + # Do precondition validation first. + if slot_copy._arguments is not None: + if extracted.m is not None: + val_results = await extracted.m.avalidate( + reqs=slot_copy.precondition_requirements, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. + assert extracted.context is not None + assert extracted.backend is not None + val_results = await mfuncs.avalidate( + reqs=slot_copy.precondition_requirements, + context=extracted.context, + backend=extracted.backend, + model_options=extracted.model_options, + output=ModelOutputThunk(slot_copy._arguments.value), + ) + + # No retries if precondition validation fails. 
+ if not all(bool(val_result) for val_result in val_results): + FancyLogger.get_logger().error( + "generative slot arguments did not satisfy precondition requirements" + ) + raise PreconditionException( + "generative slot arguments did not satisfy precondition requirements", + validation_results=val_results, + ) + + elif len(slot_copy.precondition_requirements) > 0: + FancyLogger.get_logger().warning( + "calling a generative slot with precondition requirements but no args to validate the preconditions against; ignoring precondition validation" + ) + + if extracted.m is not None: + response = await extracted.m.aact( + slot_copy, + requirements=slot_copy.requirements, + strategy=extracted.strategy, + format=response_model, + model_options=extracted.model_options, + ) + else: + # We know these aren't None from the `extract_args_and_kwargs` function. + assert extracted.context is not None + assert extracted.backend is not None + response, context = await mfuncs.aact( + slot_copy, + extracted.context, + extracted.backend, + requirements=slot_copy.requirements, + strategy=extracted.strategy, + format=response_model, + model_options=extracted.model_options, + ) function_response: FunctionResponse[R] = response_model.model_validate_json( response.value # type: ignore ) - return function_response.result + + if context is None: + return function_response.result + else: + return function_response.result, context return __async_call__() +@overload +def generative(func: Callable[P, Awaitable[R]]) -> AsyncGenerativeSlot[P, R]: ... # type: ignore + + +@overload +def generative(func: Callable[P, R]) -> SyncGenerativeSlot[P, R]: ... + + def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: """Convert a function into an AI-powered function. @@ -264,13 +680,29 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: that function's behavior. The output is guaranteed to match the return type annotation using structured outputs and automatic validation. 
- Note: Works with async functions as well. + Notes: + - Works with async functions as well. + - Must pass all parameters for the original function as keyword args. + - Most python type-hinters will not show the default values but will correctly infer them; + this means that you can set default values in the decorated function and the only necessary values will be a session or a (context, backend). Tip: Write the function and docstring in the most Pythonic way possible, not like a prompt. This ensures the function is well-documented, easily understood, and familiar to any Python developer. The more natural and conventional your function definition, the better the AI will understand and imitate it. + The new function has the following additional args: + *m*: MelleaSession: A mellea session (optional: must set context and backend if None) + *context*: Context: the Context object (optional: session must be set if None) + *backend*: Backend: the backend used for generation (optional: session must be set if None) + *precondition_requirements*: list[Requirement | str] | None: A list of requirements that the genslot inputs are validated against; raises a PreconditionException if not met. + *requirements*: list[Requirement | str] | None: A list of requirements that the genslot output can be validated against. + *strategy*: SamplingStrategy | None: A SamplingStrategy that describes the strategy for validating and repairing/retrying. None means that no particular sampling strategy is used. + *model_options*: dict | None: Model options to pass to the backend. + + The requirements and validation for the generative function operate over a textual representation + of the arguments / outputs (not their python objects). + Args: func: Function with docstring and type hints. Implementation can be empty (...). @@ -279,7 +711,10 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: original function's signature and docstring. 
Raises: - ValidationError: if the generated output cannot be parsed into the expected return type. Typically happens when the token limit for the generated output results in invalid json. + ValueError: (raised by @generative) if the decorated function has a parameter name used by generative slots + ValidationError: (raised when calling the generative slot) if the generated output cannot be parsed into the expected return type. Typically happens when the token limit for the generated output results in invalid json. + TypeError: (raised when calling the generative slot) if any of the original function's parameters were passed as positional args + PreconditionException: (raised when calling the generative slot) if the precondition validation of the args fails; catch the exception to get the validation results Examples: >>> from mellea import generative, start_session @@ -289,7 +724,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... '''Generate a concise summary of the input text.''' ... ... >>> - >>> summary = summarize_text(session, "Long text...", max_words=30) + >>> summary = summarize_text(session, text="Long text...", max_words=30) >>> from typing import List >>> from dataclasses import dataclass @@ -313,7 +748,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... ''' ... ... >>> - >>> tasks = await create_project_tasks(session, "Build a web app", 5) + >>> tasks = await create_project_tasks(session, project_desc="Build a web app", count=5) >>> @generative ... def analyze_code_quality(code: str) -> Dict[str, Any]: @@ -333,7 +768,7 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: >>> >>> analysis = analyze_code_quality( ... session, - ... "def factorial(n): return n * factorial(n-1)", + ... code="def factorial(n): return n * factorial(n-1)", ... model_options={"temperature": 0.3} ... ) @@ -355,13 +790,13 @@ def generative(func: Callable[P, R]) -> GenerativeSlot[P, R]: ... ''' ... ... 
>>> - >>> reasoning = generate_chain_of_thought(session, "How to optimize a slow database query?") + >>> reasoning = generate_chain_of_thought(session, problem="How to optimize a slow database query?") """ if inspect.iscoroutinefunction(func): return AsyncGenerativeSlot(func) else: - return GenerativeSlot(func) + return SyncGenerativeSlot(func) -# Export the decorator as the interface -__all__ = ["generative"] +# Export the decorator as the interface. Export the specific exception for debugging. +__all__ = ["PreconditionException", "generative"] diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 6520a7da..e5b5bc5e 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -166,7 +166,7 @@ async def sample( context=result_ctx, backend=backend, output=result, - format=format, + format=None, model_options=model_options, # tool_calls=tool_calls # Don't support using tool calls in validation strategies. ) diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index 59c402b2..0f8aaffb 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -132,7 +132,7 @@ async def sample( context=result_ctx, backend=backend, output=result, - format=format, + format=None, model_options=model_options, input=next_action._description, # type: ignore # tool_calls=tool_calls # Don't support using tool calls in validation strategies. diff --git a/mellea/templates/prompts/default/ArgPreconditionRequirement.jinja2 b/mellea/templates/prompts/default/ArgPreconditionRequirement.jinja2 new file mode 100644 index 00000000..7c96c046 --- /dev/null +++ b/mellea/templates/prompts/default/ArgPreconditionRequirement.jinja2 @@ -0,0 +1,17 @@ +Please check if the following arguments satisfy the precondition. +Reply with 'yes' if the precondition is satisfied and 'no' otherwise. +Do not include any other text in your response. 
+ +{%- block output -%} +{% if output %} + +Arguments: +{{ output -}} +{%- endif -%} +{% endblock output%} + +{%- block description %} +{% if description %} +Precondition: {{ description -}} +{%- endif -%} +{% endblock description %} diff --git a/mellea/templates/prompts/default/GenerativeSlot.jinja2 b/mellea/templates/prompts/default/GenerativeSlot.jinja2 index 309bfc9b..8305dedc 100644 --- a/mellea/templates/prompts/default/GenerativeSlot.jinja2 +++ b/mellea/templates/prompts/default/GenerativeSlot.jinja2 @@ -2,13 +2,19 @@ Your task is to imitate the output of the following function for the given argum Reply Nothing else but the output of the function. Function: -def {{ function.name }}({{ function.signature }}) - """{{ function.docstring | default("No documentation provided.") }}""" +def {{ function.name }}{{ function.signature }}: + """{{ function.docstring | default("No documentation provided.") -}} + {% if requirements|length > 0 %} -{%- if arguments|length > 0 %} + Postconditions: + {%- for req in requirements %} + - {{ req -}} + {% endfor %} + {% endif -%} + """ + +{%- if arguments is not none %} Arguments: -{%- for arg in arguments %} -- {{ arg.name }}: {{ arg.value }} (type: {{ arg.annotation -}}) -{%- endfor -%} +{{ arguments }} {%- endif -%} \ No newline at end of file diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index c9695260..78f43f14 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -2,8 +2,25 @@ import pytest from typing import Literal from mellea import generative, start_session -from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot - +from mellea.backends.model_ids import META_LLAMA_3_2_1B +from mellea.backends.ollama import OllamaModelBackend +from mellea.stdlib.base import ChatContext, Context +from mellea.stdlib.genslot import AsyncGenerativeSlot, GenerativeSlot, PreconditionException, SyncGenerativeSlot +from mellea.stdlib.requirement import 
Requirement, simple_validate +from mellea.stdlib.sampling.base import RejectionSamplingStrategy +from mellea.stdlib.session import MelleaSession + +@pytest.fixture(scope="module") +def backend(gh_run: int): + """Shared backend.""" + if gh_run == 1: + return OllamaModelBackend( + model_id=META_LLAMA_3_2_1B.ollama_name, # type: ignore + ) + else: + return OllamaModelBackend( + model_id="granite3.3:8b", + ) @generative def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... @@ -33,7 +50,7 @@ def test_gen_slot_output(classify_sentiment_output): def test_func(session): - assert isinstance(write_me_an_email, GenerativeSlot) and not isinstance(write_me_an_email, AsyncGenerativeSlot) + assert isinstance(write_me_an_email, SyncGenerativeSlot) write_email_component = write_me_an_email(session) assert isinstance(write_email_component, str) @@ -48,18 +65,128 @@ def test_gen_slot_logs(classify_sentiment_output, session): assert isinstance(last_prompt, dict) assert set(last_prompt.keys()) == {"role", "content", "images"} +def test_gen_slot_with_context_and_backend(session): + email, context = write_me_an_email(context=session.ctx, backend=session.backend) + assert isinstance(email, str) + assert isinstance(context, Context) + async def test_async_gen_slot(session): assert isinstance(async_write_short_sentence, AsyncGenerativeSlot) r1 = async_write_short_sentence(session, topic="cats") r2 = async_write_short_sentence(session, topic="dogs") - r3 = await async_write_short_sentence(session, topic="fish") + r3, c3 = await async_write_short_sentence(context=session.ctx, backend=session.backend, topic="fish") results = await asyncio.gather(r1, r2) assert isinstance(r3, str) + assert isinstance(c3, Context) assert len(results) == 2 +@pytest.mark.parametrize( + "arg_choices,kwarg_choices,errs", + [ + pytest.param(["m"], ["func1", "func2", "func3"], False, id="session"), + pytest.param(["context"], ["backend"], False, id="context and backend"), + 
pytest.param(["backend"], ["func1", "func2", "func3"], True, id="backend without context"), + pytest.param(["m"], ["m"], True, id="duplicate arg and kwarg"), + pytest.param(["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], [], True, id="original func args as positional args"), + pytest.param([], ["m", "func1", "func2", "func3"], False, id="session and func as kwargs"), + pytest.param([], ["m", "precondition_requirements", "requirements", "strategy", "model_options", "func1", "func2", "func3"], False, id="all kwargs"), + pytest.param([], ["func1", "m", "func2", "requirements", "func3"], False, id="interspersed kwargs"), + pytest.param([], [], True, id="missing required args") + ] +) +def test_arg_extraction(backend, arg_choices, kwarg_choices, errs): + """Tests the internal extract_args_and_kwargs function. + + This function has to test a large number of input combinations; as a result, + it uses a parameterization scheme. It takes a list and a dict. Each contains + strings corresponding to the possible args/kwargs below. Order matters in the list. + See the param id for an idea of what the test does. + + Python should catch most of these issues itself. We have to manually raise an exception for + the arguments of the original function being positional. + """ + + # List of all needed values. + backend = backend + ctx = ChatContext() + session = MelleaSession(backend, ctx) + precondition_requirements = ["precondition"] + requirements = None + strategy = RejectionSamplingStrategy() + model_options = {"test": 1} + func1 = 1 + func2 = True + func3 = "func3" + + # Lookup table by name. 
+ vals = { + "m": session, + "backend": backend, + "context": ctx, + "precondition_requirements": precondition_requirements, + "requirements": requirements, + "strategy": strategy, + "model_options": model_options, + "func1": func1, + "func2": func2, + "func3": func3, + } + + args = [] + for arg in arg_choices: + args.append(vals[arg]) + + kwargs = {} + for kwarg in kwarg_choices: + kwargs[kwarg] = vals[kwarg] + + # Run the extraction and check for the (un-)expected exception. + found_err = False + err = None + try: + GenerativeSlot.extract_args_and_kwargs(*args, **kwargs) + except Exception as e: + found_err = True + err = e + + if errs: + assert found_err, "expected an exception and got none" + else: + assert not found_err, f"got unexpected err: {err}" + +def test_disallowed_parameter_names(): + with pytest.raises(ValueError): + @generative + def test(backend): + ... + +def test_precondition_failure(session): + with pytest.raises(PreconditionException): + classify_sentiment( + m=session, + text="hello", + precondition_requirements=[ + Requirement("forced failure", validation_fn=simple_validate(lambda x: (False, ""))) + ] + ) + +def test_requirement(session): + classify_sentiment( + m=session, + text="hello", + requirements=["req1", "req2", Requirement("req3")] + ) + +def test_with_no_args(session): + @generative + def generate_text() -> str: + """Generate text!""" + ... + + generate_text(m=session) if __name__ == "__main__": pytest.main([__file__])