# Shared per-argument documentation used to assemble the docstrings of the
# functional API (`act`, `instruct`, `chat`, `validate`, `query`, `transform`
# and their async counterparts). Keys are argument names exactly as they are
# passed to `format_docs`; values are the text placed after "<name>: ".
DOCS = {
    # from `act`
    "action": "the Component from which to generate.",
    "context": "the context being used as a history from which to generate the response.",
    "backend": "the backend used to generate the response.",
    # NOTE: `act`/`aact` document "requirements" differently from `instruct`;
    # they supply their own wording through the **custom override of
    # `format_docs`, so only the `instruct` wording is stored here (below).
    "strategy": (
        "a SamplingStrategy that describes the strategy for validating and "
        "repairing/retrying for the instruct-validate-repair pattern. "
        "None means that no particular sampling strategy is used."
    ),
    "return_sampling_results": "attach the (successful and failed) sampling attempts to the results.",
    "format": "if set, the BaseModel to use for constrained decoding.",
    "model_options": "additional model options, which will upsert into the model/backend's defaults.",
    "tool_calls": "if true, tool calling is enabled.",
    # from `instruct`
    "description": "The description of the instruction.",
    "requirements": "A list of requirements that the instruction can be validated against.",
    "icl_examples": "A list of in-context-learning examples that the instruction can be validated against.",
    "grounding_context": (
        "A list of grounding contexts that the instruction can use. "
        "They can bind as variables using a (key: str, value: str | ContentBlock) tuple."
    ),
    "user_variables": (
        "A dict of user-defined variables used to fill in Jinja placeholders in other parameters. "
        "This requires that all other provided parameters are provided as strings."
    ),
    "prefix": "A prefix string or ContentBlock to use when generating the instruction.",
    "output_prefix": (
        "A string or ContentBlock that defines a prefix for the output generation. "
        "Usually you do not need this."
    ),
    "images": "A list of images to be used in the instruction or None if none.",
    # from `chat`
    "content": "A message content string in a chat",
    "role": "the role of the chat message, e.g. 'user'.",
    # from `validate`
    "output": "the output to validate; if not provided, the current context is validated instead.",
    "generate_logs": "an optional list of GenerateLog objects.",
    # from `query`
    # The key must be exactly "obj": a trailing space here makes every
    # `format_docs(..., ["obj", ...])` call fail with KeyError at import time.
    "obj": (
        "The object to be queried. It should be an instance of MObject "
        "or can be converted to one if necessary."
    ),
    "query": " The string representing the query to be executed against the object.",
    # from `aact`
    "silence_context_type_warning": (
        "if called directly from an asynchronous function, "
        "will log a warning if not using a SimpleContext"
    ),
    # from `transform`
    "transformation": " The string representing the query to be executed against the object.",
}


def format_docs(main: str, argnames: list[str], returns: str, **custom: str) -> str:
    """Build a Google-style docstring from the shared DOCS table.

    Each name in ``argnames`` is looked up in DOCS (after applying the
    ``custom`` overrides) and rendered as one ``name: text`` entry of the
    ``Args:`` section, in the order given. Leading/trailing whitespace of an
    entry's text is stripped so that stray spaces in the table do not leak
    into the generated docstring.

    Args:
        main: Summary text placed at the top of the docstring.
        argnames: Names of the arguments to document, in display order.
        returns: Text placed in the ``Returns:`` section.
        **custom: Per-call documentation entries. They are added to a copy of
            DOCS for this call only and supersede an existing entry of the
            same name; DOCS itself is never modified.

    Returns:
        The docstring text: a leading newline, the summary, an ``Args:``
        section (omitted when ``argnames`` is empty), a ``Returns:`` section
        and a trailing indented line, with sections separated by blank lines.

    Raises:
        KeyError: If any name in ``argnames`` is found neither in DOCS nor in
            ``custom``. The message lists every missing name.
    """
    base_indent = "    "
    entry_indent = base_indent * 2

    # Merge into a fresh dict so per-call overrides never touch the module table.
    docs = {**DOCS, **custom}

    # Report all undocumented names at once instead of failing on the first
    # one with a bare key; this runs at import time, so the message matters.
    missing = [name for name in argnames if name not in docs]
    if missing:
        raise KeyError(
            f"no documentation entry for argument(s) {missing}; "
            "add them to DOCS or pass them as keyword arguments to format_docs"
        )

    sections = [main]
    if argnames:
        entries = "\n".join(
            f"{entry_indent}{name}: {docs[name].strip()}" for name in argnames
        )
        sections.append(f"Args:\n{entries}")
    sections.append(f"Returns:\n{entry_indent}{returns}")

    body = "\n\n".join(f"{base_indent}{section}" for section in sections)
    return f"\n{body}\n{base_indent}"
- grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. - user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. - prefix: A prefix string or ContentBlock to use when generating the instruction. - output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. - strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: If set, the BaseModel to use for constrained decoding. - model_options: Additional model options, which will upsert into the model/backend's defaults. - tool_calls: If true, tool calling is enabled. - images: A list of images to be used in the instruction or None if none. - - Returns: - A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. - """ requirements = [] if requirements is None else requirements icl_examples = [] if icl_examples is None else icl_examples grounding_context = dict() if grounding_context is None else grounding_context @@ -236,7 +257,6 @@ def chat( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[Message, Context]: - """Sends a simple chat message and returns the response. Adds both messages to the Context.""" if user_variables is not None: content_resolved = Instruction.apply_user_dict_from_jinja( user_variables, content @@ -273,7 +293,6 @@ def validate( | None = None, # TODO: Can we get rid of gen logs here and in act? 
input: CBlock | None = None, ) -> list[ValidationResult]: - """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" # Run everything in the specific event loop for this session. out = _run_async_in_thread( @@ -303,20 +322,6 @@ def query( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk, Context]: - """Query method for retrieving information from an object. - - Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - query: The string representing the query to be executed against the object. - context: the context being used as a history from which to generate the response. - backend: the backend used to generate the response. - format: format for output parsing. - model_options: Model options to pass to the backend. - tool_calls: If true, the model may make tool calls. Defaults to False. - - Returns: - ModelOutputThunk: The result of the query as processed by the backend. - """ if not isinstance(obj, MObjectProtocol): obj = mify(obj) @@ -344,21 +349,6 @@ def transform( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, ) -> tuple[ModelOutputThunk | Any, Context]: - """Transform method for creating a new object with the transformation applied. - - Args: - obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - transformation: The string representing the query to be executed against the object. - context: the context being used as a history from which to generate the response. - backend: the backend used to generate the response. - format: format for output parsing; usually not needed with transform. - model_options: Model options to pass to the backend. - - Returns: - (ModelOutputThunk | Any, Context): The result of the transformation as processed by the backend. 
If no tools were called, - the return type will be always be (ModelOutputThunk, Context). If a tool was called, the return type will be the return type - of the function called, usually the type of the object passed in. - """ if not isinstance(obj, MObjectProtocol): obj = mify(obj) @@ -461,23 +451,6 @@ async def aact( tool_calls: bool = False, silence_context_type_warning: bool = False, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: - """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. - - Args: - action: the Component from which to generate. - context: the context being used as a history from which to generate the response. - backend: the backend used to generate the response. - requirements: used as additional requirements when a sampling strategy is provided - strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: if set, the BaseModel to use for constrained decoding. - model_options: additional model options, which will upsert into the model/backend's defaults. - tool_calls: if true, tool calling is enabled. - silence_context_type_warning: if called directly from an asynchronous function, will log a warning if not using a SimpleContext - - Returns: - A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. - """ if not silence_context_type_warning and not isinstance(context, SimpleContext): FancyLogger().get_logger().warning( "Not using a SimpleContext with asynchronous requests could cause unexpected results due to stale contexts. Ensure you await between requests." 
@@ -609,28 +582,6 @@ async def ainstruct( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk, Context] | SamplingResult: - """Generates from an instruction. - - Args: - description: The description of the instruction. - context: the context being used as a history from which to generate the response. - backend: the backend used to generate the response. - requirements: A list of requirements that the instruction can be validated against. - icl_examples: A list of in-context-learning examples that the instruction can be validated against. - grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. - user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. - prefix: A prefix string or ContentBlock to use when generating the instruction. - output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. - strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: If set, the BaseModel to use for constrained decoding. - model_options: Additional model options, which will upsert into the model/backend's defaults. - tool_calls: If true, tool calling is enabled. - images: A list of images to be used in the instruction or None if none. - - Returns: - A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, else returns a `SamplingResult`. 
- """ requirements = [] if requirements is None else requirements icl_examples = [] if icl_examples is None else icl_examples grounding_context = dict() if grounding_context is None else grounding_context @@ -674,7 +625,6 @@ async def achat( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[Message, Context]: - """Sends a simple chat message and returns the response. Adds both messages to the Context.""" if user_variables is not None: content_resolved = Instruction.apply_user_dict_from_jinja( user_variables, content @@ -710,7 +660,6 @@ async def avalidate( generate_logs: list[GenerateLog] | None = None, input: CBlock | None = None, ) -> list[ValidationResult]: - """Asynchronous version of .validate; validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" # Turn a solitary requirement in to a list of requirements, and then reqify if needed. reqs = [reqs] if not isinstance(reqs, list) else reqs reqs = [Requirement(req) if type(req) is str else req for req in reqs] @@ -766,20 +715,6 @@ async def aquery( model_options: dict | None = None, tool_calls: bool = False, ) -> tuple[ModelOutputThunk, Context]: - """Query method for retrieving information from an object. - - Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - query: The string representing the query to be executed against the object. - context: the context being used as a history from which to generate the response. - backend: the backend used to generate the response. - format: format for output parsing. - model_options: Model options to pass to the backend. - tool_calls: If true, the model may make tool calls. Defaults to False. - - Returns: - ModelOutputThunk: The result of the query as processed by the backend. 
- """ if not isinstance(obj, MObjectProtocol): obj = mify(obj) @@ -807,21 +742,6 @@ async def atransform( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, ) -> tuple[ModelOutputThunk | Any, Context]: - """Transform method for creating a new object with the transformation applied. - - Args: - obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - transformation: The string representing the query to be executed against the object. - context: the context being used as a history from which to generate the response. - backend: the backend used to generate the response. - format: format for output parsing; usually not needed with transform. - model_options: Model options to pass to the backend. - - Returns: - ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, - the return type will be always be ModelOutputThunk. If a tool was called, the return type will be the return type - of the function called, usually the type of the object passed in. 
# ---------------------------------------------------------------------------
# Docstring installation.
#
# The public functions above carry no inline docstrings; their documentation is
# generated here with `format_docs`, which looks every listed argument name up
# in DOCS (plus any keyword overrides given per call). The argument lists and
# return texts shared by a sync function and its async counterpart are defined
# once so the two cannot drift apart.
#
# NOTE(review): every name listed below must resolve in DOCS or be supplied as
# a keyword override, otherwise `format_docs` raises at import time.
# ---------------------------------------------------------------------------

_SAMPLING_RETURNS = (
    "A (ModelOutputThunk, Context) if `return_sampling_results` is `False`, "
    "else returns a `SamplingResult`."
)
_CHAT_RETURNS = "A (Message, Context)."
_VALIDATE_RETURNS = "A list of ValidationResult."
_QUERY_RETURNS = "ModelOutputThunk: The result of the query as processed by the backend."
_TRANSFORM_RETURNS = (
    "(ModelOutputThunk | Any, Context): The result of the transformation as processed by the backend. If no tools were called, "
    "the return type will always be (ModelOutputThunk, Context). If a tool was called, the return type will be the return type "
    "of the function called, usually the type of the object passed in."
)

# `act`/`aact` use their own wording for "requirements" instead of the
# `instruct` wording stored in DOCS.
_ACT_REQUIREMENTS_DOC = (
    "used as additional requirements when a sampling strategy is provided."
)

_ACT_ARGS = [
    "action",
    "context",
    "backend",
    "requirements",
    "strategy",
    "return_sampling_results",
    "format",
    "model_options",
    "tool_calls",
]
_INSTRUCT_ARGS = [
    "description",
    "context",
    "backend",
    "requirements",
    "icl_examples",
    "grounding_context",
    "user_variables",
    "prefix",
    "output_prefix",
    "strategy",
    "return_sampling_results",
    "format",
    "model_options",
    "tool_calls",
    "images",
]
_CHAT_ARGS = [
    "content",
    "context",
    "backend",
    "role",
    "images",
    "user_variables",
    "format",
    "model_options",
    "tool_calls",
]
_VALIDATE_ARGS = [
    "reqs",
    "context",
    "backend",
    "output",
    "format",
    "model_options",
    "generate_logs",
]
_QUERY_ARGS = [
    "obj",
    "query",
    "context",
    "backend",
    "format",
    "model_options",
    "tool_calls",
]
_TRANSFORM_ARGS = [
    "obj",
    "transformation",
    "context",
    "backend",
    "format",
    "model_options",
]

# --- synchronous API -------------------------------------------------------

act.__doc__ = format_docs(
    "Runs a generic action, and adds both the action and the result to the context.",
    _ACT_ARGS,
    _SAMPLING_RETURNS,
    requirements=_ACT_REQUIREMENTS_DOC,
)

instruct.__doc__ = format_docs(
    "Generates from an instruction.", _INSTRUCT_ARGS, _SAMPLING_RETURNS
)

chat.__doc__ = format_docs(
    "Sends a simple chat message and returns the response. Adds both messages to the Context.",
    _CHAT_ARGS,
    _CHAT_RETURNS,
)

validate.__doc__ = format_docs(
    "Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).",
    _VALIDATE_ARGS,
    _VALIDATE_RETURNS,
    # `reqs` has no entry of its own; it reuses the "requirements" wording.
    reqs=DOCS["requirements"],
)

query.__doc__ = format_docs(
    "Query method for retrieving information from an object.",
    _QUERY_ARGS,
    _QUERY_RETURNS,
)

transform.__doc__ = format_docs(
    "Transform method for creating a new object with the transformation applied.",
    _TRANSFORM_ARGS,
    _TRANSFORM_RETURNS,
)

# --- asynchronous API ------------------------------------------------------

aact.__doc__ = format_docs(
    "Asynchronous version of act.",
    # `aact` takes one argument more than `act`; document it as well.
    [*_ACT_ARGS, "silence_context_type_warning"],
    _SAMPLING_RETURNS,
    requirements=_ACT_REQUIREMENTS_DOC,
)

ainstruct.__doc__ = format_docs(
    "Asynchronous version of instruct.", _INSTRUCT_ARGS, _SAMPLING_RETURNS
)

achat.__doc__ = format_docs(
    "Asynchronous version of chat.", _CHAT_ARGS, _CHAT_RETURNS
)

avalidate.__doc__ = format_docs(
    "Asynchronous version of validate.",
    _VALIDATE_ARGS,
    _VALIDATE_RETURNS,
    reqs=DOCS["requirements"],
)

aquery.__doc__ = format_docs(
    "Asynchronous version of query.", _QUERY_ARGS, _QUERY_RETURNS
)

atransform.__doc__ = format_docs(
    "Asynchronous version of transform.", _TRANSFORM_ARGS, _TRANSFORM_RETURNS
)
BaseModelSubclass from mellea.helpers.async_helpers import wait_for_all_mots from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib import funcs as mfuncs from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk from mellea.stdlib.instruction import Instruction from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 2a63a71a..01f33e53 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -3,12 +3,15 @@ from __future__ import annotations import contextvars +import functools +import inspect +from collections.abc import Callable from copy import copy from typing import Any, Literal, overload from PIL import Image as PILImage -import mellea.stdlib.funcs as mfuncs +import mellea.stdlib.functional as mfuncs from mellea.backends import Backend, BaseModelSubclass from mellea.backends.model_ids import ( IBM_GRANITE_3_3_8B, @@ -236,573 +239,48 @@ def cleanup(self) -> None: if hasattr(self.backend, "close"): self.backend.close() # type: ignore - @overload - def act( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[False] = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk: ... - - @overload - def act( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[True], - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> SamplingResult: ... 
- - def act( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: bool = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: - """Runs a generic action, and adds both the action and the result to the context. + @classmethod + def register(cls, fn: Callable, set_context: bool = True): + """Registers fn as a new method to MelleaSession. - Args: - action: the Component from which to generate. - requirements: used as additional requirements when a sampling strategy is provided - strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: if set, the BaseModel to use for constrained decoding. - model_options: additional model options, which will upsert into the model/backend's defaults. - tool_calls: if true, tool calling is enabled. - - Returns: - A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. 
- """ - r = mfuncs.act( - action, - context=self.ctx, - backend=self.backend, - requirements=requirements, - strategy=strategy, - return_sampling_results=return_sampling_results, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) # type: ignore - - if isinstance(r, SamplingResult): - self.ctx = r.result_ctx - return r - else: - result, context = r - self.ctx = context - return result - - @overload - def instruct( - self, - description: str, - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[False] = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk: ... - - @overload - def instruct( - self, - description: str, - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[True], - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> SamplingResult: ... 
- - def instruct( - self, - description: str, - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: bool = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: - """Generates from an instruction. - - Args: - description: The description of the instruction. - requirements: A list of requirements that the instruction can be validated against. - icl_examples: A list of in-context-learning examples that the instruction can be validated against. - grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. - user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. - prefix: A prefix string or ContentBlock to use when generating the instruction. - output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. - strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: If set, the BaseModel to use for constrained decoding. - model_options: Additional model options, which will upsert into the model/backend's defaults. 
- tool_calls: If true, tool calling is enabled. - images: A list of images to be used in the instruction or None if none. - """ - r = mfuncs.instruct( - description, - context=self.ctx, - backend=self.backend, - images=images, - requirements=requirements, - icl_examples=icl_examples, - grounding_context=grounding_context, - user_variables=user_variables, - prefix=prefix, - output_prefix=output_prefix, - strategy=strategy, - return_sampling_results=return_sampling_results, # type: ignore - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - - if isinstance(r, SamplingResult): - self.ctx = r.result_ctx - return r - else: - # It's a tuple[ModelOutputThunk, Context]. - result, context = r - self.ctx = context - return result - - def chat( - self, - content: str, - role: Message.Role = "user", - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - user_variables: dict[str, str] | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> Message: - """Sends a simple chat message and returns the response. 
Adds both messages to the Context.""" - result, context = mfuncs.chat( - content=content, - context=self.ctx, - backend=self.backend, - role=role, - images=images, - user_variables=user_variables, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - - self.ctx = context - return result - - def validate( - self, - reqs: Requirement | list[Requirement], - *, - output: CBlock | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, - input: CBlock | None = None, - ) -> list[ValidationResult]: - """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - return mfuncs.validate( - reqs=reqs, - context=self.ctx, - backend=self.backend, - output=output, - format=format, - model_options=model_options, - generate_logs=generate_logs, - input=input, - ) - - def query( - self, - obj: Any, - query: str, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk: - """Query method for retrieving information from an object. - - Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - query: The string representing the query to be executed against the object. - format: format for output parsing. - model_options: Model options to pass to the backend. - tool_calls: If true, the model may make tool calls. Defaults to False. - - Returns: - ModelOutputThunk: The result of the query as processed by the backend. 
- """ - result, context = mfuncs.query( - obj=obj, - query=query, - context=self.ctx, - backend=self.backend, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - self.ctx = context - return result - - def transform( - self, - obj: Any, - transformation: str, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - ) -> ModelOutputThunk | Any: - """Transform method for creating a new object with the transformation applied. - - Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - transformation: The string representing the query to be executed against the object. - format: format for output parsing; usually not needed with transform. - model_options: Model options to pass to the backend. - - Returns: - ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, - the return type will be always be ModelOutputThunk. If a tool was called, the return type will be the return type - of the function called, usually the type of the object passed in. - """ - result, context = mfuncs.transform( - obj=obj, - transformation=transformation, - context=self.ctx, - backend=self.backend, - format=format, - model_options=model_options, - ) - self.ctx = context - return result - - @overload - async def aact( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[False] = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk: ... 
- - @overload - async def aact( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[True], - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> SamplingResult: ... - - async def aact( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: bool = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: - """Runs a generic action, and adds both the action and the result to the context. - - Args: - action: the Component from which to generate. - requirements: used as additional requirements when a sampling strategy is provided - strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: if set, the BaseModel to use for constrained decoding. - model_options: additional model options, which will upsert into the model/backend's defaults. - tool_calls: if true, tool calling is enabled. - - Returns: - A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + The function fn must accept `backend` and `context` arguments. 
""" - r = await mfuncs.aact( - action, - context=self.ctx, - backend=self.backend, - requirements=requirements, - strategy=strategy, - return_sampling_results=return_sampling_results, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) # type: ignore - - if isinstance(r, SamplingResult): - self.ctx = r.result_ctx - return r - else: - result, context = r - self.ctx = context - return result - - @overload - async def ainstruct( - self, - description: str, - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[False] = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk: ... - - @overload - async def ainstruct( - self, - description: str, - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: Literal[True], - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> SamplingResult: ... 
- - async def ainstruct( - self, - description: str, - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2), - return_sampling_results: bool = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: - """Generates from an instruction. - Args: - description: The description of the instruction. - requirements: A list of requirements that the instruction can be validated against. - icl_examples: A list of in-context-learning examples that the instruction can be validated against. - grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. - user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. - prefix: A prefix string or ContentBlock to use when generating the instruction. - output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. - strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: If set, the BaseModel to use for constrained decoding. - model_options: Additional model options, which will upsert into the model/backend's defaults. 
- tool_calls: If true, tool calling is enabled. - images: A list of images to be used in the instruction or None if none. - """ - r = await mfuncs.ainstruct( - description, - context=self.ctx, - backend=self.backend, - images=images, - requirements=requirements, - icl_examples=icl_examples, - grounding_context=grounding_context, - user_variables=user_variables, - prefix=prefix, - output_prefix=output_prefix, - strategy=strategy, - return_sampling_results=return_sampling_results, # type: ignore - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - - if isinstance(r, SamplingResult): - self.ctx = r.result_ctx - return r + def postprocess(self, r): + if set_context: + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + result, context = r + self.ctx = context + return result + else: + return r + + if inspect.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def wrapper(self, *args, **kwargs): + return postprocess( + self, + await fn(backend=self.backend, context=self.ctx, *args, **kwargs), + ) else: - # It's a tuple[ModelOutputThunk, Context]. - result, context = r - self.ctx = context - return result - - async def achat( - self, - content: str, - role: Message.Role = "user", - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - user_variables: dict[str, str] | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> Message: - """Sends a simple chat message and returns the response. 
Adds both messages to the Context.""" - result, context = await mfuncs.achat( - content=content, - context=self.ctx, - backend=self.backend, - role=role, - images=images, - user_variables=user_variables, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - - self.ctx = context - return result - - async def avalidate( - self, - reqs: Requirement | list[Requirement], - *, - output: CBlock | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, - input: CBlock | None = None, - ) -> list[ValidationResult]: - """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - return await mfuncs.avalidate( - reqs=reqs, - context=self.ctx, - backend=self.backend, - output=output, - format=format, - model_options=model_options, - generate_logs=generate_logs, - input=input, - ) - async def aquery( - self, - obj: Any, - query: str, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk: - """Query method for retrieving information from an object. + @functools.wraps(fn) + def wrapper(self, *args, **kwargs): + return postprocess( + self, fn(backend=self.backend, context=self.ctx, *args, **kwargs) + ) - Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - query: The string representing the query to be executed against the object. - format: format for output parsing. - model_options: Model options to pass to the backend. - tool_calls: If true, the model may make tool calls. Defaults to False. + setattr(cls, fn.__name__, wrapper) - Returns: - ModelOutputThunk: The result of the query as processed by the backend. 
- """ - result, context = await mfuncs.aquery( - obj=obj, - query=query, - context=self.ctx, - backend=self.backend, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - self.ctx = context - return result - - async def atransform( - self, - obj: Any, - transformation: str, - *, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - ) -> ModelOutputThunk | Any: - """Transform method for creating a new object with the transformation applied. - - Args: - obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. - transformation: The string representing the query to be executed against the object. - format: format for output parsing; usually not needed with transform. - model_options: Model options to pass to the backend. - - Returns: - ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, - the return type will be always be ModelOutputThunk. If a tool was called, the return type will be the return type - of the function called, usually the type of the object passed in. 
- """ - result, context = await mfuncs.atransform( - obj=obj, - transformation=transformation, - context=self.ctx, - backend=self.backend, - format=format, - model_options=model_options, - ) - self.ctx = context - return result + @classmethod + def powerup(cls, powerup_cls: type): + """Appends methods in a class object `powerup_cls` to MelleaSession.""" + for name, fn in inspect.getmembers(powerup_cls, predicate=inspect.isfunction): + setattr(cls, name, fn) # ############################### # Convenience functions @@ -824,3 +302,18 @@ def last_prompt(self) -> str | list[dict] | None: if isinstance(last_el, GenerateLog): return last_el.prompt return None + + +MelleaSession.register(mfuncs.act) +MelleaSession.register(mfuncs.instruct) +MelleaSession.register(mfuncs.chat) +MelleaSession.register(mfuncs.validate, set_context=False) +MelleaSession.register(mfuncs.query) +MelleaSession.register(mfuncs.transform) + +MelleaSession.register(mfuncs.aact) +MelleaSession.register(mfuncs.ainstruct) +MelleaSession.register(mfuncs.achat) +MelleaSession.register(mfuncs.avalidate, set_context=False) +MelleaSession.register(mfuncs.aquery) +MelleaSession.register(mfuncs.atransform) diff --git a/test/stdlib_basics/test_funcs.py b/test/stdlib_basics/test_functional.py similarity index 94% rename from test/stdlib_basics/test_funcs.py rename to test/stdlib_basics/test_functional.py index 4e99afc5..924b5b8c 100644 --- a/test/stdlib_basics/test_funcs.py +++ b/test/stdlib_basics/test_functional.py @@ -5,7 +5,7 @@ from mellea.backends.types import ModelOption from mellea.stdlib.base import ModelOutputThunk from mellea.stdlib.chat import Message -from mellea.stdlib.funcs import instruct, aact, avalidate, ainstruct +from mellea.stdlib.functional import instruct, aact, avalidate, ainstruct from mellea.stdlib.requirement import req from mellea.stdlib.session import start_session @@ -77,4 +77,4 @@ async def test_avalidate(m_session): if __name__ == "__main__": - pytest.main([__file__]) \ No 
newline at end of file + pytest.main([__file__])