diff --git a/.github/actions/free-disk-space/action.yml b/.github/actions/free-disk-space/action.yml new file mode 100644 index 00000000..e2af3e9d --- /dev/null +++ b/.github/actions/free-disk-space/action.yml @@ -0,0 +1,27 @@ +name: 'Free Disk Space' +description: 'Frees disk space on the runner' +runs: + using: "composite" + steps: + - name: Print disk space before cleanup + run: | + df -h + shell: bash + - name: Free Disk Space Linux + if: runner.os == 'Linux' + run: | + sudo docker rmi "$(docker image ls -aq)" >/dev/null 2>&1 || true + sudo rm -rf \ + /usr/share/dotnet /usr/local/lib/android /opt/ghc \ + /usr/local/share/powershell /usr/share/swift /usr/local/.ghcup \ + /usr/lib/jvm || true + sudo apt install aptitude -y >/dev/null 2>&1 + sudo aptitude purge '~n ^mysql' -f -y >/dev/null 2>&1 + sudo aptitude purge '~n ^dotnet' -f -y >/dev/null 2>&1 + sudo apt-get autoremove -y >/dev/null 2>&1 + sudo apt-get autoclean -y >/dev/null 2>&1 + shell: bash + - name: Print disk space after cleanup + run: | + df -h + shell: bash \ No newline at end of file diff --git a/.github/workflows/quality.yml b/.github/workflows/quality.yml index c6dbd8e9..69dec08d 100644 --- a/.github/workflows/quality.yml +++ b/.github/workflows/quality.yml @@ -21,6 +21,8 @@ jobs: python-version: ['3.10', '3.11', '3.12'] # Need to add 3.13 once we resolve outlines issues. steps: - uses: actions/checkout@v4 + - name: Free disk space + uses: ./.github/actions/free-disk-space - name: Install uv and set the python version uses: astral-sh/setup-uv@v5 with: diff --git a/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py b/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py index 56ca36c1..6ee529b2 100644 --- a/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py +++ b/cli/decompose/prompt_modules/constraint_extractor/_constraint_extractor.py @@ -4,6 +4,7 @@ from mellea import MelleaSession from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock from mellea.stdlib.instruction import Instruction from .._prompt_modules import PromptModule, PromptModuleString @@ -114,9 +115,8 @@ def generate( # type: ignore[override] instruction = Instruction(description=user_prompt, prefix=system_prompt) try: - gen_result = mellea_session.backend.generate_from_context( + gen_result = mellea_session.act( action=instruction, - ctx=mellea_session.ctx, model_options={ ModelOption.TEMPERATURE: 0, ModelOption.MAX_NEW_TOKENS: max_new_tokens, diff --git a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py index 9efe1723..36cb866e 100644 --- a/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py +++ b/cli/decompose/prompt_modules/subtask_constraint_assign/_subtask_constraint_assign.py @@ -216,9 +216,8 @@ def generate( # type: ignore[override] instruction = Instruction(description=user_prompt, prefix=system_prompt) try: - gen_result = mellea_session.backend.generate_from_context( + gen_result = mellea_session.act( action=instruction, - ctx=mellea_session.ctx, model_options={ ModelOption.TEMPERATURE: 0, ModelOption.MAX_NEW_TOKENS: max_new_tokens, diff --git a/cli/decompose/prompt_modules/subtask_list/_subtask_list.py b/cli/decompose/prompt_modules/subtask_list/_subtask_list.py index 216dc1c5..4f00b257 100644 --- a/cli/decompose/prompt_modules/subtask_list/_subtask_list.py +++ 
b/cli/decompose/prompt_modules/subtask_list/_subtask_list.py @@ -144,9 +144,8 @@ def generate( instruction = Instruction(description=user_prompt, prefix=system_prompt) try: - gen_result = mellea_session.backend.generate_from_context( + gen_result = mellea_session.act( action=instruction, - ctx=mellea_session.ctx, model_options={ ModelOption.TEMPERATURE: 0, ModelOption.MAX_NEW_TOKENS: max_new_tokens, diff --git a/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py b/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py index 47c6a226..734282b4 100644 --- a/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py +++ b/cli/decompose/prompt_modules/subtask_prompt_generator/_subtask_prompt_generator.py @@ -218,9 +218,8 @@ def generate( # type: ignore[override] instruction = Instruction(description=user_prompt, prefix=system_prompt) try: - gen_result = mellea_session.backend.generate_from_context( + gen_result = mellea_session.act( action=instruction, - ctx=mellea_session.ctx, model_options={ ModelOption.TEMPERATURE: 0, ModelOption.MAX_NEW_TOKENS: max_new_tokens, diff --git a/docs/examples/mify/rich_document_advanced.py b/docs/examples/mify/rich_document_advanced.py index 58e6814b..186ddd90 100644 --- a/docs/examples/mify/rich_document_advanced.py +++ b/docs/examples/mify/rich_document_advanced.py @@ -44,7 +44,7 @@ # Note: Because the template for a RichDocument just outputs it as markdown, # the model doesn't really know what to do with it in this context. However, this # is a useful pattern if you want to use a component with a specified template. -thunk = m.backend.generate_from_context(action=rd, ctx=m.ctx) +thunk = m.act(action=rd) print(thunk.value) # > - user: What is the primary goal of the GLTR tool... # 5. The class is opinionated and outputs the document as markdown to the model (like in the initial example). @@ -87,7 +87,7 @@ def from_document_file( rds.format_for_llm().args ) # > {'titles': ['GLTR: Statistical Detection and Visualization of Generated Text', 'Abstract', ..., 'References']} -thunk = m.backend.generate_from_context(action=rds, ctx=m.ctx) +thunk = m.act(action=rds) print(thunk.value) # > The document appears to be an academic research paper... # 6. We can also pass this document as grounding context to an instruction. diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index c7d8e1c3..5711aa73 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -41,10 +41,9 @@ def generate_from_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: # i.e., ContextDiff - """Generates a model output from a context. May not mutate the context. + """Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request. Args: action: The last item of the context should be passed in as an `action` instead of as part of the `ctx`. See `docs/dev/generate_signature_decisions.md`. 
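A minimal usage sketch of the new call path introduced above, assuming an already configured MelleaSession `m` and an existing `Instruction` named `instruction` (the token budget below is illustrative, not a library default). With `generate_logs` removed and generation scheduled as an asyncio task, callers go through `MelleaSession.act(...)` and read the returned ModelOutputThunk, as the updated decompose modules and docs examples do:

    from mellea.backends.types import ModelOption

    thunk = m.act(
        action=instruction,
        model_options={
            ModelOption.TEMPERATURE: 0,
            ModelOption.MAX_NEW_TOKENS: 1024,
        },
    )
    print(thunk.value)  # the generated text, as in the docs examples updated in this patch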
diff --git a/mellea/backends/aloras/__init__.py b/mellea/backends/aloras/__init__.py index 9a825f4c..ae7b37b2 100644 --- a/mellea/backends/aloras/__init__.py +++ b/mellea/backends/aloras/__init__.py @@ -2,7 +2,7 @@ import abc -from mellea.stdlib.base import CBlock +from mellea.stdlib.base import CBlock, ModelOutputThunk class Alora(abc.ABC): @@ -24,8 +24,8 @@ def __init__(self, name: str): self.name: str = name @abc.abstractmethod - def generate_using_strings(self, *args, **kwargs) -> str: - """Generates from the ALora using raw strings as the interface for both inputs and outputs. + def generate_using_strings(self, *args, **kwargs) -> ModelOutputThunk: + """Generates from the ALora using raw strings as the interface for inputs. In most cases, must be run from a running event loop. This has a generic signature because each aLoRA has different parameters depending on its functionality and how it gets called. """ diff --git a/mellea/backends/aloras/huggingface/granite_aloras.py b/mellea/backends/aloras/huggingface/granite_aloras.py index 87dab75c..2e1e7284 100644 --- a/mellea/backends/aloras/huggingface/granite_aloras.py +++ b/mellea/backends/aloras/huggingface/granite_aloras.py @@ -1,12 +1,17 @@ """Huggingface implementations for IBM's "starter pack" of Activated LoRAs.""" +import asyncio +import functools from copy import deepcopy import torch +from transformers.generation.utils import GenerateDecoderOnlyOutput from mellea.backends.huggingface import HFAlora, HFAloraCacheInfo, LocalHFBackend from mellea.backends.types import ModelOption +from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import GenerateType, ModelOutputThunk class HFConstraintAlora(HFAlora): @@ -44,9 +49,14 @@ def __init__( self._logger = FancyLogger.get_logger() def generate_using_strings( - self, input: str, response: str, constraint: str, force_yn: bool = True - ) -> str: - """Generates a constraint response from the ALora.""" + self, + input: str, + response: str, + constraint: str, + force_yn: bool = True, + stream: bool = False, + ) -> ModelOutputThunk: + """Generates a constraint response from the ALora. Must be run in a running event loop.""" assert self._backend.alora_model is not None # Go ahead and do runtime type-checking because passing CBlocks into this function is a common error. assert type(input) is str @@ -54,21 +64,84 @@ def generate_using_strings( assert type(constraint) is str self._backend.alora_model.set_adapter(self.name) cache_hit = self._backend.cache_get(response) + + if stream: + self._logger.warning( + "`HFConstraintAlora` cannot stream output; defaulting to non-streaming approach." + ) + + generate_kwargs = {} if cache_hit: self._logger.debug( f"using cache for alora {self.__class__} and response '{response}'" ) - return self._generate_using_cache(cache_hit, constraint, force_yn) + generate_kwargs["past_key_values"] = deepcopy(cache_hit.kv_cache) + input_combined = self._generate_using_cache(cache_hit, constraint, force_yn) + else: self._logger.debug( f"not using cache for alora {self.__class__} and response '{response}'" ) - return self._generate_not_using_cache(input, response, constraint, force_yn) + input_combined = self._generate_not_using_cache( + input, response, constraint, force_yn + ) + + if not self._include_constraint_in_alora_offset: + alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1] + else: + # Get the constraint tokens separately so that we can calculate the alora offsets. 
+ constraint_tokens = self._backend._tokenizer( + self._constraint_prompt.format(constraint), return_tensors="pt" + ).to(self._backend._device) + + alora_offsets = [ + constraint_tokens["input_ids"].shape[1] + + self._generation_prompt_tokens["input_ids"].shape[1] + - 2 + ] + + chat_response = asyncio.to_thread( + self._backend.alora_model.generate, + input_combined["input_ids"].to(self._backend._device), + attention_mask=input_combined["attention_mask"].to(self._backend._device), + max_new_tokens=1, + return_dict_in_generate=True, + alora_offsets=alora_offsets, + output_scores=True, + **generate_kwargs, + ) + + output = ModelOutputThunk(None) + output._meta["alora_name"] = self.name + + output._process = functools.partial( + processing, + backend=self._backend, + force_yn=force_yn, + gen_prompt=self._generation_prompt, + ) + output._post_process = functools.partial(post_processing, backend=self._backend) + + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. + + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. + output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) # type: ignore + ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present. + raise e + + return output def _generate_using_cache( self, cache_hit: HFAloraCacheInfo, constraint: str, force_yn: bool - ) -> str: - assert self._backend.alora_model is not None + ) -> dict: + """Returns the input object used for generation.""" # Must tokenize the constraint here since the requirement isn't known at initialization. 
constraint_tokens = self._backend._tokenizer( @@ -94,62 +167,16 @@ def _generate_using_cache( ), } - if not self._include_constraint_in_alora_offset: - alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1] - else: - alora_offsets = [ - constraint_tokens["input_ids"].shape[1] - + self._generation_prompt_tokens["input_ids"].shape[1] - - 2 - ] self._logger.debug( f"Prompt for cached aLoRA({self.name}):\n {self._backend._tokenizer.decode(input_combined['input_ids'][0])}" ) - if force_yn: - output = self._backend.alora_model.generate( - input_combined["input_ids"].to(self._backend._device), - attention_mask=input_combined["attention_mask"].to( - self._backend._device - ), - max_new_tokens=1, - past_key_values=deepcopy(cache_hit.kv_cache), - return_dict_in_generate=True, - alora_offsets=alora_offsets, - output_scores=True, - ) - last_logits = output.scores[-1].squeeze(0) - token_Y = self._backend._tokenizer("Y", add_special_tokens=False)[ - "input_ids" - ][0] - token_N = self._backend._tokenizer("N", add_special_tokens=False)[ - "input_ids" - ][0] - logit_Y = last_logits[token_Y].item() - logit_N = last_logits[token_N].item() - return "Y" if logit_Y > logit_N else "N" - else: - output = self._backend.alora_model.generate( - input_combined["input_ids"].to(self._backend._device), - attention_mask=input_combined["attention_mask"].to( - self._backend._device - ), - max_new_tokens=1, - past_key_values=deepcopy(cache_hit.kv_cache), - return_dict_in_generate=True, - alora_offsets=alora_offsets, - ) - output_text = self._backend._tokenizer.decode(output.sequences[0]) - assert output_text[-1] in ["Y", "N"], ( - f"The constraint model card states the the Requirement Checker model will respond with 'Y' or 'N', but found: {output_text}" - ) - constraint_satisfied = output_text.split(self._generation_prompt)[-1] - return constraint_satisfied[0] + return input_combined def _generate_not_using_cache( self, input: str, response: str, constraint: str, force_yn: bool - ) -> str: - assert self._backend.alora_model is not None + ) -> dict: + """Returns the input object used for generation.""" # Params aren't needed when just getting the backend args. backend_model_opts = self._backend._simplify_and_merge(None) @@ -185,58 +212,43 @@ def _generate_not_using_cache( ), } - if not self._include_constraint_in_alora_offset: - alora_offsets = [self._generation_prompt_tokens["input_ids"].shape[1] - 1] - else: - # Get the constraint tokens separately so that we can calculate the alora offsets. 
- constraint_tokens = self._backend._tokenizer( - self._constraint_prompt.format(constraint), return_tensors="pt" - ).to(self._backend._device) - - alora_offsets = [ - constraint_tokens["input_ids"].shape[1] - + self._generation_prompt_tokens["input_ids"].shape[1] - - 2 - ] - self._logger.debug( f"Prompt for non-cached aLoRA({self.name}):\n{self._backend._tokenizer.decode(input_combined['input_ids'][0])}" ) - if force_yn: - output = self._backend.alora_model.generate( - input_combined["input_ids"].to(self._backend._device), - attention_mask=input_combined["attention_mask"].to( - self._backend._device - ), - max_new_tokens=1, - return_dict_in_generate=True, - alora_offsets=alora_offsets, - output_scores=True, - ) - last_logits = output.scores[-1].squeeze(0) - token_Y = self._backend._tokenizer("Y", add_special_tokens=False)[ - "input_ids" - ][0] - token_N = self._backend._tokenizer("N", add_special_tokens=False)[ - "input_ids" - ][0] - logit_Y = last_logits[token_Y].item() - logit_N = last_logits[token_N].item() - return "Y" if logit_Y > logit_N else "N" - else: - output = self._backend.alora_model.generate( - input_combined["input_ids"].to(self._backend._device), - attention_mask=input_combined["attention_mask"].to( - self._backend._device - ), - max_new_tokens=1, - return_dict_in_generate=True, - alora_offsets=alora_offsets, - ) - output_text = self._backend._tokenizer.decode(output.sequences[0]) - constraint_satisfied = output_text.split(self._generation_prompt)[-1] - return constraint_satisfied[0] + return input_combined + + +async def processing( + mot: ModelOutputThunk, + chunk: GenerateDecoderOnlyOutput, + backend: LocalHFBackend, + force_yn: bool, + gen_prompt: str, +): + if mot._underlying_value is None: + mot._underlying_value = "" + + # Don't support async for HFConstraintAlora. Means we can process the output here. + assert isinstance(chunk, GenerateDecoderOnlyOutput) + + if force_yn: + last_logits = chunk.scores[-1].squeeze(0) # type: ignore + token_Y = backend._tokenizer("Y", add_special_tokens=False)["input_ids"][0] # type: ignore + token_N = backend._tokenizer("N", add_special_tokens=False)["input_ids"][0] # type: ignore + logit_Y = last_logits[token_Y].item() + logit_N = last_logits[token_N].item() + mot._underlying_value = "Y" if logit_Y > logit_N else "N" + else: + output_text = backend._tokenizer.decode(chunk.sequences[0]) + constraint_satisfied = output_text.split(gen_prompt)[-1] + mot._underlying_value = constraint_satisfied[ + 0 + ] # Grab the first char of the str. 
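The `force_yn` branch of the new `processing` step above reduces the aLoRA constraint check to a single-token decision: compare the final-step logits of the literal "Y" and "N" tokens and keep whichever is larger. A standalone sketch of that decision rule, assuming the caller already has the last-step logits tensor and the two token ids (the helper name is illustrative, not part of the library):

    import torch

    def pick_yes_no(last_logits: torch.Tensor, token_y: int, token_n: int) -> str:
        """Return "Y" when the "Y" token out-scores the "N" token at the final step."""
        return "Y" if last_logits[token_y].item() > last_logits[token_n].item() else "N"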
+ + +async def post_processing(mot: ModelOutputThunk, backend: LocalHFBackend): + backend.formatter.parse(mot._action, mot) # type: ignore def add_granite_aloras(backend: LocalHFBackend): diff --git a/mellea/backends/aloras/openai/granite_aloras.py b/mellea/backends/aloras/openai/granite_aloras.py index 23dcc124..a6d17172 100644 --- a/mellea/backends/aloras/openai/granite_aloras.py +++ b/mellea/backends/aloras/openai/granite_aloras.py @@ -1,11 +1,19 @@ """OpenAI implementations for IBM's "starter pack" of Activated LoRAs.""" +import asyncio +import functools +from collections.abc import Coroutine +from typing import Any + import openai +from openai.types.completion import Completion from mellea.backends.aloras import Alora from mellea.backends.openai import OpenAIAlora, OpenAIBackend from mellea.backends.types import ModelOption +from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import GenerateType, ModelOutputThunk class OpenAIConstraintAlora(OpenAIAlora): @@ -21,9 +29,14 @@ def __init__( self._logger = FancyLogger.get_logger() def generate_using_strings( - self, input: str, response: str, constraint: str, force_yn: bool = True - ) -> str: - """Generates a constraint response from the ALora.""" + self, + input: str, + response: str, + constraint: str, + force_yn: bool = True, + stream: bool = False, + ) -> ModelOutputThunk: + """Generates a constraint response from the ALora. Must be run in a running event loop.""" # Go ahead and do runtime type-checking because passing CBlocks into this function is a common error. assert type(input) is str assert type(response) is str @@ -40,38 +53,65 @@ def generate_using_strings( ] prompt = self._backend.apply_chat_template(chat) - prompt += f"\nRequirement: {constraint}<|end_of_text|>\n" + prompt += f"\nRequirement: {constraint}<|end_of_text|>\n" # type: ignore prompt += self._generation_prompt self._logger.debug(f"Prompt for non-cached aLoRA({self.name}):\n{prompt}") + force_yn_args = {} if force_yn: assert hasattr(self._backend, "_tokenizer") token_Y = self._backend._tokenizer("Y", add_special_tokens=False)[ "input_ids" - ][0] + ][0] # type: ignore token_N = self._backend._tokenizer("N", add_special_tokens=False)[ "input_ids" - ][0] - return ( - self._backend._client.completions.create( - model=self.name, - prompt=prompt, - max_tokens=1, - n=1, - logit_bias={str(token_Y): 100, str(token_N): 100}, - ) - .choices[0] - .text - ) - else: - return ( - self._backend._client.completions.create( - model=self.name, prompt=prompt, max_tokens=1, n=1 - ) - .choices[0] - .text + ][0] # type: ignore + + force_yn_args["logit_bias"] = {str(token_Y): 100, str(token_N): 100} + + chat_response: Coroutine[ + Any, Any, openai.AsyncStream[Completion] | Completion + ] = self._backend._async_client.completions.create( + model=self.name, + prompt=prompt, + max_tokens=1, + n=1, + stream=stream, + **force_yn_args, + ) # type: ignore + + output = ModelOutputThunk(None) + output._meta["alora_name"] = self.name + + output._process = processing + output._post_process = functools.partial(post_processing, backend=self._backend) + + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. + + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. 
+ output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present + raise e + + return output + + +async def processing(mot: ModelOutputThunk, chunk: Completion): + if mot._underlying_value is None: + mot._underlying_value = "" + mot._underlying_value += chunk.choices[0].text + + +async def post_processing(backend: OpenAIBackend, mot: ModelOutputThunk): + backend.formatter.parse(mot._action, mot) # type: ignore def add_granite_aloras(backend: OpenAIBackend): diff --git a/mellea/backends/dummy.py b/mellea/backends/dummy.py index 86bc1586..c2216c6e 100644 --- a/mellea/backends/dummy.py +++ b/mellea/backends/dummy.py @@ -23,7 +23,6 @@ def generate_from_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: """See constructor for an exmplanation of how DummyBackends work.""" diff --git a/mellea/backends/formatter.py b/mellea/backends/formatter.py index f2eba1a6..0b296eff 100644 --- a/mellea/backends/formatter.py +++ b/mellea/backends/formatter.py @@ -44,7 +44,9 @@ def print_context(self, ctx: Context) -> str: def parse( self, source_component: Component | CBlock, result: ModelOutputThunk ) -> ModelOutputThunk: - """Parses the output from a model.""" + """Parses the output from a model and sets the parsed_repr of the result ModelOutputThunk. + + Returns the ModelOutputThunk that was passed in.""" ... def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]: diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 261ebfd2..ef50eb40 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -6,17 +6,20 @@ from __future__ import annotations import abc +import asyncio import dataclasses import datetime +import functools import inspect import json -from collections.abc import Callable +from collections.abc import Callable, Coroutine from typing import TYPE_CHECKING, Any import outlines import outlines_core import torch from transformers import ( + AsyncTextIteratorStreamer, AutoModelForCausalLM, AutoTokenizer, DynamicCache, @@ -24,6 +27,7 @@ PreTrainedTokenizer, set_seed, ) +from transformers.generation.utils import GenerateDecoderOnlyOutput from mellea.backends import BaseModelSubclass from mellea.backends.aloras import Alora, AloraBackendMixin @@ -38,12 +42,14 @@ parse_tools, ) from mellea.backends.types import ModelOption +from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( CBlock, Component, Context, GenerateLog, + GenerateType, ModelOutputThunk, ModelToolCall, ) @@ -117,6 +123,7 @@ def __init__( "max_new_tokens": ModelOption.MAX_NEW_TOKENS, "seed": ModelOption.SEED, "tools": ModelOption.TOOLS, + "stream": ModelOption.STREAM, } # A mapping of Mellea specific ModelOptions to the specific names for this backend. 
@@ -183,7 +190,6 @@ def generate_from_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: """Generate using the huggingface model.""" @@ -203,19 +209,10 @@ def generate_from_context( reroute_to_alora = True if reroute_to_alora: return self._generate_from_context_alora( - action, - ctx, - format=format, - model_options=model_opts, - generate_logs=generate_logs, + action, ctx, format=format, model_options=model_opts ) return self._generate_from_context_standard( - action, - ctx, - format=format, - model_options=model_opts, - generate_logs=generate_logs, - tool_calls=tool_calls, + action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls ) def _generate_from_context_alora( @@ -225,7 +222,6 @@ def _generate_from_context_alora( *, format: type[BaseModelSubclass] | None = None, model_options: dict[str, Any], - generate_logs: list[GenerateLog] | None = None, ) -> ModelOutputThunk: match action: case ALoraRequirement(): @@ -248,29 +244,23 @@ def _generate_from_context_alora( assert type(user_message) is str assert type(assistant_message) is str assert format is None, "Structured outputs are not supported by ALoRAs." + alora_output = alora_for_this_request.generate_using_strings( input=user_message, response=assistant_message, constraint=action.description, # type: ignore + stream=model_options.get(ModelOption.STREAM, False), ) - if generate_logs is not None: - mess = user_message.replace("\n", "\\n") - assist_mess = assistant_message.replace("\n", "\\n") - log = GenerateLog( - prompt=f"aLora(name='{alora_for_this_request.name}', input='{mess}', response='{assist_mess}', constraint='{action.description}') ", # type: ignore - result=ModelOutputThunk(alora_output), - model_options=model_options, - date=datetime.datetime.now(), - ) - generate_logs.append(log) + # The alora function doesn't set up all the fields. + alora_output._context = linearized_ctx + alora_output._action = action + alora_output._model_options = model_options - return self.formatter.parse( - action, - ModelOutputThunk( - alora_output, meta={"alora_name": alora_for_this_request.name} - ), - ) + # TODO: Figure out what info we want to populate for aloras here. + alora_output._generate_log = GenerateLog() + + return alora_output def _generate_from_context_standard( self, @@ -279,13 +269,11 @@ def _generate_from_context_standard( *, format: type[BaseModelSubclass] | None = None, model_options: dict[str, Any], - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: # Construct input. # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template. # Otherwise, we will linearize the context and treat it as a raw input. - decoded_result: str | None = None if ctx.is_chat_context: linearized_ctx = ctx.render_for_generation() assert linearized_ctx is not None, ( @@ -346,90 +334,186 @@ def _generate_from_context_standard( **self._make_backend_specific_and_remove(model_options), ).to(self._device) # type: ignore - if format is None: - chat_output = self._model.generate( # type: ignore - input_ids, - return_dict_in_generate=True, - output_scores=True, - **self._make_backend_specific_and_remove(model_options), - ) # type: ignore - - else: + format_kwargs = {} + if format: # outlines.generate.json always parses the resulting json into a python dict. 
# We however want to keep it as a json string for later storing it in ModelOutputThunk schema: dict[str, Any] = format.model_json_schema() schema_json: str = json.dumps(schema) - regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( + regex_str: str = outlines_core.fsm.json_schema.build_regex_from_schema( # type: ignore schema_json ) from outlines.models.transformers import TransformerTokenizer - from outlines.processors import RegexLogitsProcessor + from outlines.processors.structured import RegexLogitsProcessor from transformers import LogitsProcessorList - chat_output = self._model.generate( # type: ignore - input_ids, - return_dict_in_generate=True, - output_scores=True, - logits_processor=LogitsProcessorList( - [ - RegexLogitsProcessor( - regex_str, - tokenizer=TransformerTokenizer(self._tokenizer), - ) - ] - ), - **self._make_backend_specific_and_remove(model_options), + format_kwargs["logits_processor"] = LogitsProcessorList( + [ + RegexLogitsProcessor( + regex_str, tokenizer=TransformerTokenizer(self._tokenizer) + ) + ] ) - decoded_result = self._tokenizer.decode( - chat_output.sequences[0, input_ids.shape[1] :], skip_special_tokens=True + streaming_kwargs = {} + streamer = None + stream = model_options.get(ModelOption.STREAM, False) + if stream: + try: + # HuggingFace uses a streaming interface that you pass to the generate call. + # Must be called from a running event loop. This should always be the case given the same + # requirement of the ._generate function below. + streamer = AsyncTextIteratorStreamer( + self._tokenizer, # type: ignore + skip_prompt=True, + skip_special_tokens=True, + ) + streaming_kwargs["streamer"] = streamer + except RuntimeError as e: + # Most likely cause is creating this object without an event loop present. + raise e + + # Create a separate thread to handle the processing. Make it awaitable + # for non-streaming cases and to get the final output. + # Details: https://huggingface.co/docs/transformers/en/internal/generation_utils#transformers.AsyncTextIteratorStreamer + chat_response = asyncio.to_thread( + self._model.generate, # type: ignore + input_ids, + return_dict_in_generate=True, + output_scores=True, + **self._make_backend_specific_and_remove(model_options), + **streaming_kwargs, # type: ignore + **format_kwargs, # type: ignore ) - # Add an entry to the cache for ALora reuse. - if self._use_caches: - output_complete = chat_output.sequences[0] - cache: DynamicCache = chat_output.past_key_values + output = ModelOutputThunk(None) + output._context = linearized_ctx + output._action = action + output._model_options = model_options + + # Processing functions only pass the ModelOutputThunk (and current chunk of response). Bind the other vars necessary for + # each processing step. + output._process = functools.partial(self.processing, input_ids=input_ids) + output._post_process = functools.partial( + self.post_processing, + conversation=ctx_as_conversation, + input_ids=input_ids, + tool_calls=tool_calls, + tools=tools, + seed=seed, + ) - cache_info = HFAloraCacheInfo( - kv_cache=cache, - merged_token_ids=output_complete, - merged_attention=torch.ones_like(output_complete).to(self._device), - q_end=len(input_ids[0]), + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. 
+ + response: AsyncTextIteratorStreamer | Coroutine = chat_response + if stream and streamer is not None: + # For streaming, we want to pass the AsyncIterator to the function. Unlike other backends, + # this isn't returned by the chat_response coroutine. So we handle it here. + response = streamer + + # Since the async iterator isn't returned by the chat_response coroutine, we have to create a separate + # task for it here so that it runs in the background. Attach it to the ModelOutputThunk. + output._generate_extra = asyncio.create_task(chat_response) + + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. + output._generate = asyncio.create_task( + send_to_queue(response, output._async_queue) # type: ignore ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present. + raise e + + return output - assert decoded_result is not None - self.cache_put(decoded_result, cache_info) else: raise Exception("Does not yet support non-chat contexts.") - assert decoded_result is not None + async def processing( + self, mot: ModelOutputThunk, chunk: str | GenerateDecoderOnlyOutput, input_ids + ): + """Process the returned chunks or the complete response.""" + if mot._underlying_value is None: + mot._underlying_value = "" + + # Because we use the AsyncTextIteratorStreamer, streaming responses are of type str; + # and already decoded. + if isinstance(chunk, str): + mot._underlying_value += chunk + else: + # Otherwise, it's a non-streaming request. Decode it here. + mot._meta["hf_output"] = chunk + mot._underlying_value += self._tokenizer.decode( + chunk.sequences[0, input_ids.shape[1] :], skip_special_tokens=True + ) + + async def post_processing( + self, + mot: ModelOutputThunk, + conversation: list[dict], + tool_calls: bool, + tools: dict[str, Callable], + seed, + input_ids, + ): + """Called when generation is done.""" + if mot._meta.get("hf_output", None) is None: + if mot._generate_extra is not None: + full_output = await mot._generate_extra + assert isinstance(full_output, GenerateDecoderOnlyOutput) + mot._meta["hf_output"] = full_output + + # The ModelOutputThunk must be computed by this point. + assert mot.value is not None + + # Add an entry to the cache for ALora reuse. + if self._use_caches: + output_complete = mot._meta["hf_output"].sequences[0] + cache: DynamicCache = mot._meta["hf_output"].past_key_values # type: ignore + + cache_info = HFAloraCacheInfo( + kv_cache=cache, + merged_token_ids=output_complete, + merged_attention=torch.ones_like(output_complete).to(self._device), + q_end=len(input_ids[0]), # type: ignore + ) - result = ModelOutputThunk(value=decoded_result) + self.cache_put(mot.value, cache_info) - # Only scan for tools if we are not doing structured decoding and tool calls were provided to the model. + # Only scan for tools if we are not doing structured output and tool calls were provided to the model. 
if format is None and tool_calls: - result.tool_calls = self._extract_model_tool_requests(tools, decoded_result) + mot.tool_calls = self._extract_model_tool_requests(tools, mot.value) - parsed_result = self.formatter.parse(action, result) - if generate_logs is not None: - assert isinstance(generate_logs, list) - generate_log = GenerateLog() - generate_log.prompt = ctx_as_conversation - generate_log.backend = f"hf::{self.model_id!s}" - generate_log.model_options = model_options - generate_log.date = datetime.datetime.now() - generate_log.model_output = decoded_result - generate_log.extra = { - "format": format, - "tools_available": tools, - "tools_called": result.tool_calls, - "seed": seed, - } - generate_log.action = action - generate_log.result = parsed_result - generate_logs.append(generate_log) - return parsed_result + assert mot._action is not None, ( + "ModelOutputThunks should have their action assigned during generation" + ) + assert mot._model_options is not None, ( + "ModelOutputThunks should have their model_opts assigned during generation" + ) + + self.formatter.parse(mot._action, mot) + + # Generate the log for this ModelOutputThunk. + generate_log = GenerateLog() + generate_log.prompt = conversation + generate_log.backend = f"hf::{self.model_id!s}" + generate_log.model_options = mot._model_options + generate_log.date = datetime.datetime.now() + generate_log.model_output = mot.value + generate_log.extra = { + "format": format, + "tools_available": tools, + "tools_called": mot.tool_calls, + "seed": seed, + } + generate_log.action = mot._action + generate_log.result = mot + + mot._generate_log = generate_log def _generate_from_raw( self, diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index b330bbcd..23ea446e 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -1,8 +1,11 @@ """A generic LiteLLM compatible backend that wraps around the openai python sdk.""" +import asyncio import datetime +import functools import json -from collections.abc import Callable +import os +from collections.abc import Callable, Coroutine from typing import Any import litellm # type: ignore @@ -19,12 +22,18 @@ convert_tools_to_json, ) from mellea.backends.types import ModelOption +from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.openai_compatible_helpers import ( + chat_completion_delta_merge, + extract_model_tool_requests, +) from mellea.stdlib.base import ( CBlock, Component, Context, GenerateLog, + GenerateType, ModelOutputThunk, ModelToolCall, ) @@ -74,12 +83,13 @@ def __init__( # users should only be specifying a single one in their request. self.to_mellea_model_opts_map = { "system": ModelOption.SYSTEM_PROMPT, - "reasoning_effort": ModelOption.THINKING, # TODO: JAL; see which of these are actually extracted... + "reasoning_effort": ModelOption.THINKING, "seed": ModelOption.SEED, "max_completion_tokens": ModelOption.MAX_NEW_TOKENS, "max_tokens": ModelOption.MAX_NEW_TOKENS, "tools": ModelOption.TOOLS, "functions": ModelOption.TOOLS, + "stream": ModelOption.STREAM, } # A mapping of Mellea specific ModelOptions to the specific names for this backend. 
@@ -90,6 +100,7 @@ def __init__( self.from_mellea_model_opts_map = { ModelOption.SEED: "seed", ModelOption.MAX_NEW_TOKENS: "max_completion_tokens", + ModelOption.STREAM: "stream", } def generate_from_context( @@ -99,7 +110,6 @@ def generate_from_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ): """See `generate_from_chat_context`.""" @@ -111,7 +121,6 @@ def generate_from_context( ctx, format=format, model_options=model_options, - generate_logs=generate_logs, tool_calls=tool_calls, ) @@ -204,7 +213,6 @@ def _generate_from_chat_context_standard( format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: model_opts = self._simplify_and_merge(model_options) @@ -259,7 +267,9 @@ def _generate_from_chat_context_standard( model_specific_options = self._make_backend_specific_and_remove(model_opts) - chat_response: litellm.ModelResponse = litellm.completion( + chat_response: Coroutine[ + Any, Any, litellm.ModelResponse | litellm.ModelResponseStream # type: ignore + ] = litellm.acompletion( model=self._model_id, messages=conversation, tools=formatted_tools, @@ -269,39 +279,147 @@ def _generate_from_chat_context_standard( **model_specific_options, ) - choice_0 = chat_response.choices[0] - assert isinstance(choice_0, litellm.utils.Choices), ( - "Only works for non-streaming response for now" - ) - result = ModelOutputThunk( - value=choice_0.message.content, - meta={ - "litellm_chat_response": chat_response.choices[0].model_dump() - }, # NOTE: Using model dump here to comply with `TemplateFormatter` - tool_calls=self._extract_model_tool_requests(tools, chat_response), + output = ModelOutputThunk(None) + output._context = linearized_context + output._action = action + output._model_options = model_opts + + # Processing functions only pass the ModelOutputThunk (and current chunk of response). Bind the other vars necessary for + # each processing step. + output._process = self.processing + output._post_process = functools.partial( + self.post_processing, + conversation=conversation, + tools=tools, + thinking=thinking, ) - parsed_result = self.formatter.parse(source_component=action, result=result) - - if generate_logs is not None: - assert isinstance(generate_logs, list) - generate_log = GenerateLog() - generate_log.prompt = conversation - generate_log.backend = f"litellm::{self.model_id!s}" - generate_log.model_options = model_specific_options - generate_log.date = datetime.datetime.now() - generate_log.model_output = chat_response - generate_log.extra = { - "format": format, - "tools_available": tools, - "tools_called": result.tool_calls, - "seed": model_opts.get("seed", None), - } - generate_log.action = action - generate_log.result = parsed_result - generate_logs.append(generate_log) + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. + + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. 
+ output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) + ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present + raise e + + return output + + async def processing( + self, + mot: ModelOutputThunk, + chunk: litellm.ModelResponse | litellm.ModelResponseStream, # type: ignore + ): + """Called during generation to add information from a single ModelResponse or a chunk / ModelResponseStream to the ModelOutputThunk. + + For LiteLLM, tool call parsing is handled in the post processing step.""" + if mot._thinking is None: + mot._thinking = "" + if mot._underlying_value is None: + mot._underlying_value = "" + + if isinstance(chunk, litellm.ModelResponse): # type: ignore + # choice should always be a `Choice`. There's some type weirdness going + # on with how litellm have defined the `.choices` list. + choice = chunk.choices[0] + assert isinstance(choice, litellm.Choices) + + message = choice.message + + # Sometimes a message doesn't actually have this field. + if hasattr(message, "reasoning_content"): + thinking_chunk = message.reasoning_content + if thinking_chunk is not None: + mot._thinking += thinking_chunk + + content_chunk = message.content + if content_chunk is not None: + mot._underlying_value += content_chunk + + mot._meta["litellm_chat_response"] = chunk.choices[0].model_dump() + + elif isinstance(chunk, litellm.ModelResponseStream): # type: ignore + message_delta = chunk.choices[0].delta + + # Sometimes a delta doesn't actually have this field. + if hasattr(message_delta, "reasoning_content"): + thinking_chunk = message_delta.reasoning_content + if thinking_chunk is not None: + mot._thinking += thinking_chunk + + content_chunk = message_delta.content + if content_chunk is not None: + mot._underlying_value += content_chunk + + if mot._meta.get("litellm_chat_response_streamed", None) is None: + mot._meta["litellm_chat_response_streamed"] = [] + mot._meta["litellm_chat_response_streamed"].append( + chunk.choices[0].model_dump() + ) + + async def post_processing( + self, + mot: ModelOutputThunk, + conversation: list[dict], + tools: dict[str, Callable], + thinking, + ): + """Called when generation is done.""" + # Reconstruct the chat_response from chunks if streamed. + streamed_chunks = mot._meta.get("litellm_chat_response_streamed", None) + if streamed_chunks is not None: + # Must handle ollama differently due to: https://github.com/BerriAI/litellm/issues/14579. + # Check that we are targeting ollama with the model_id prefix litellm uses. + separate_tools = False + if "ollama" in self._model_id.split("/")[0]: + separate_tools = True + mot._meta["litellm_chat_response"] = chat_completion_delta_merge( + streamed_chunks, force_all_tool_calls_separate=separate_tools + ) - return parsed_result + assert mot._action is not None, ( + "ModelOutputThunks should have their action assigned during generation" + ) + assert mot._model_options is not None, ( + "ModelOutputThunks should have their model_opts assigned during generation" + ) + + # OpenAI-like streamed responses potentially give you chunks of tool calls. + # As a result, we have to store data between calls and only then + # check for complete tool calls in the post_processing step. + tool_chunk = extract_model_tool_requests( + tools, mot._meta["litellm_chat_response"] + ) + if tool_chunk is not None: + if mot.tool_calls is None: + mot.tool_calls = {} + # Merge the tool_chunk dict. 
+ for key, val in tool_chunk.items(): + mot.tool_calls[key] = val + + self.formatter.parse(mot._action, mot) + + # Generate the log for this ModelOutputThunk. + generate_log = GenerateLog() + generate_log.prompt = conversation + generate_log.backend = f"litellm::{self.model_id!s}" + generate_log.model_options = mot._model_options + generate_log.date = datetime.datetime.now() + generate_log.model_output = mot._meta["litellm_chat_response"] + generate_log.extra = { + "format": format, + "tools_available": tools, + "tools_called": mot.tool_calls, + "seed": thinking, + } + generate_log.action = mot._action + generate_log.result = mot + mot._generate_log = generate_log @staticmethod def _extract_tools( @@ -335,11 +453,13 @@ def _generate_from_raw( raise NotImplementedError("This method is not implemented yet.") def _extract_model_tool_requests( - self, tools: dict[str, Callable], chat_response: litellm.ModelResponse + self, + tools: dict[str, Callable], + chat_response: litellm.ModelResponse, # type: ignore ) -> dict[str, ModelToolCall] | None: model_tool_calls: dict[str, ModelToolCall] = {} choice_0 = chat_response.choices[0] - assert isinstance(choice_0, litellm.utils.Choices), ( + assert isinstance(choice_0, litellm.utils.Choices), ( # type: ignore "Only works for non-streaming response for now" ) calls = choice_0.message.tool_calls diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index 76907dc1..a4fe1324 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -2,7 +2,8 @@ import asyncio import datetime -from collections.abc import Callable +import functools +from collections.abc import AsyncIterator, Callable, Coroutine from typing import Any import ollama @@ -17,12 +18,14 @@ add_tools_from_model_options, ) from mellea.backends.types import ModelOption +from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger from mellea.stdlib.base import ( CBlock, Component, Context, GenerateLog, + GenerateType, ModelOutputThunk, ModelToolCall, TemplateRepresentation, @@ -84,6 +87,7 @@ def __init__( "num_predict": ModelOption.MAX_NEW_TOKENS, "seed": ModelOption.SEED, "tools": ModelOption.TOOLS, + "stream": ModelOption.STREAM, } # A mapping of Mellea specific ModelOptions to the specific names for this backend. @@ -232,7 +236,6 @@ def generate_from_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ): """See `generate_from_chat_context`.""" @@ -244,7 +247,6 @@ def generate_from_context( ctx, format=format, model_options=model_options, - generate_logs=generate_logs, tool_calls=tool_calls, ) @@ -255,13 +257,17 @@ def generate_from_chat_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: - """Generates a new completion from the provided Context using this backend's `Formatter`. + """Generates a ModelOutputThunk. The final value for this object can be awaited. + + The new completion is generated from the provided Context using this backend's `Formatter`. This implementation treats the `Context` as a chat history, and uses the `ollama.Client.chat()` interface to generate a completion. This will not always work, because sometimes we want to use non-chat models. + + Raises: + RuntimeError: If not called from a thread with a running event loop. 
""" model_opts = self._simplify_and_merge(model_options) @@ -310,46 +316,46 @@ def generate_from_chat_context( add_tools_from_context_actions(tools, [action]) FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") - # Generate a chat response from ollama, using the chat messages. - chat_response: ollama.ChatResponse = self._client.chat( + # Generate a chat response from ollama, using the chat messages. Can be either type since stream is passed as a model option. + chat_response: Coroutine[ + Any, Any, AsyncIterator[ollama.ChatResponse] | ollama.ChatResponse + ] = self._async_client.chat( model=self._get_ollama_model_id(), messages=conversation, tools=list(tools.values()), think=model_opts.get(ModelOption.THINKING, None), + stream=model_opts.get(ModelOption.STREAM, False), options=self._make_backend_specific_and_remove(model_opts), - stream=False, format=format.model_json_schema() if format is not None else None, ) # type: ignore - result = ModelOutputThunk( - value=chat_response.message.content, # For an ollama tool call, content will be an empty string. - meta={"chat_response": chat_response}, - tool_calls=self._extract_model_tool_requests(tools, chat_response), + output = ModelOutputThunk(None) + output._context = linearized_context + output._action = action + output._model_options = model_opts + + # Processing functions only pass the ModelOutputThunk (and current chunk of response). Bind the other vars necessary for + # each processing step. + output._process = functools.partial(self.processing, tools=tools) + output._post_process = functools.partial( + self.post_processing, conversation=conversation, tools=tools ) - formatted_result = self.formatter.parse(action, result) - - if generate_logs is not None: - # noinspection DuplicatedCode - assert isinstance(generate_logs, list) - generate_log = GenerateLog() - generate_log.prompt = conversation - generate_log.backend = f"ollama::{self.model_id!s}" - generate_log.model_options = model_opts - generate_log.date = datetime.datetime.now() - generate_log.model_output = chat_response - generate_log.extra = { - "format": format, - "thinking": model_opts.get(ModelOption.THINKING, None), - "tools_available": tools, - "tools_called": result.tool_calls, - "seed": model_opts.get(ModelOption.SEED, None), - } - generate_log.action = action - generate_log.result = formatted_result - generate_logs.append(generate_log) - - return formatted_result + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. + + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. 
+ output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) + ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present + raise e + + return output def _generate_from_raw( self, @@ -456,3 +462,113 @@ def _extract_model_tool_requests( if len(model_tool_calls) > 0: return model_tool_calls return None + + async def processing( + self, + mot: ModelOutputThunk, + chunk: ollama.ChatResponse, + tools: dict[str, Callable], + ): + """Called during generation to add information from a single ChatResponse to the ModelOutputThunk.""" + if mot._thinking is None: + mot._thinking = "" + thinking_chunk = chunk.message.thinking + if thinking_chunk is not None: + mot._thinking += thinking_chunk + + if mot._underlying_value is None: + mot._underlying_value = "" + content_chunk = chunk.message.content + if content_chunk is not None: + mot._underlying_value += content_chunk + + tool_chunk = self._extract_model_tool_requests(tools, chunk) + if tool_chunk is not None: + # Only set tool_calls if there is one. + if mot.tool_calls is None: + mot.tool_calls = {} + + # Merge the tool_chunk dict. + for key, val in tool_chunk.items(): + mot.tool_calls[key] = val + + # Ollama responses are mostly self-contained. Merge chunks immediately. + chat_response_delta_merge(mot, chunk) + + async def post_processing( + self, + mot: ModelOutputThunk, + conversation: list[dict], + tools: dict[str, Callable], + ): + """Called when generation is done.""" + assert mot._action is not None, ( + "ModelOutputThunks should have their action assigned during generation" + ) + assert mot._model_options is not None, ( + "ModelOutputThunks should have their model_opts assigned during generation" + ) + self.formatter.parse(mot._action, mot) + + # Generate the log for this ModelOutputThunk. + generate_log = GenerateLog() + generate_log.prompt = conversation + generate_log.backend = f"ollama::{self._get_ollama_model_id()}" + generate_log.model_options = mot._model_options + generate_log.date = datetime.datetime.now() + generate_log.model_output = mot._meta["chat_response"] + generate_log.extra = { + "format": format, + "thinking": mot._model_options.get(ModelOption.THINKING, None), + "tools_available": tools, + "tools_called": mot.tool_calls, + "seed": mot._model_options.get(ModelOption.SEED, None), + } + generate_log.action = mot._action + generate_log.result = mot + + mot._generate_log = generate_log + mot._generate = None + + +def chat_response_delta_merge(mot: ModelOutputThunk, delta: ollama.ChatResponse): + if mot._meta.get("chat_response", None) is None: + mot._meta["chat_response"] = delta + return # Return early, no need to merge. 
+ + merged: ollama.ChatResponse = mot._meta["chat_response"] + if not merged.done: + merged.done = delta.done + if merged.done_reason is None: + merged.done_reason = delta.done_reason + if merged.total_duration is None: + merged.total_duration = delta.total_duration + if merged.load_duration is None: + merged.load_duration = delta.load_duration + if merged.prompt_eval_count is None: + merged.prompt_eval_count = delta.prompt_eval_count + if merged.prompt_eval_duration is None: + merged.prompt_eval_duration = delta.prompt_eval_duration + if merged.eval_count is None: + merged.eval_count = delta.eval_count + + if merged.message.role == "": + merged.message.role = delta.message.role + + if merged.message.content is None: + merged.message.content = delta.message.content + elif delta.message.content is not None: + merged.message.content += delta.message.content + + if merged.message.thinking is None: + merged.message.thinking = delta.message.thinking + elif delta.message.thinking is not None: + merged.message.thinking += delta.message.thinking + + if merged.message.tool_calls is None: + merged.message.tool_calls = delta.message.tool_calls + elif delta.message.tool_calls is not None: + merged.message.tool_calls = [ + *merged.message.tool_calls, + *delta.message.tool_calls, + ] diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 918aafab..11646ad9 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -1,10 +1,12 @@ """A generic OpenAI compatible backend that wraps around the openai python sdk.""" import abc +import asyncio import datetime +import functools import inspect import json -from collections.abc import Callable +from collections.abc import Callable, Coroutine from enum import Enum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse @@ -13,6 +15,7 @@ import requests from huggingface_hub import snapshot_download from openai.types.chat import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk from openai.types.completion import Completion import mellea.backends.model_ids as model_ids @@ -26,12 +29,18 @@ convert_tools_to_json, ) from mellea.backends.types import ModelOption +from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.openai_compatible_helpers import ( + chat_completion_delta_merge, + extract_model_tool_requests, +) from mellea.stdlib.base import ( CBlock, Component, Context, GenerateLog, + GenerateType, ModelOutputThunk, ModelToolCall, ) @@ -108,6 +117,7 @@ def __init__( "max_tokens": ModelOption.MAX_NEW_TOKENS, "tools": ModelOption.TOOLS, "functions": ModelOption.TOOLS, + "stream": ModelOption.STREAM, } # A mapping of Mellea specific ModelOptions to the specific names for this backend. # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`. @@ -117,17 +127,20 @@ def __init__( self.from_mellea_model_opts_map_chats = { ModelOption.SEED: "seed", ModelOption.MAX_NEW_TOKENS: "max_completion_tokens", + ModelOption.STREAM: "stream", } # See notes above. self.to_mellea_model_opts_map_completions = { "seed": ModelOption.SEED, "max_tokens": ModelOption.MAX_NEW_TOKENS, + "stream": ModelOption.STREAM, } # See notes above. 
self.from_mellea_model_opts_map_completions = { ModelOption.SEED: "seed", ModelOption.MAX_NEW_TOKENS: "max_tokens", + ModelOption.STREAM: "stream", } self.default_to_constraint_checking_alora = default_to_constraint_checking_alora @@ -156,6 +169,10 @@ def __init__( self._client = openai.OpenAI( # type: ignore api_key=self._api_key, base_url=self._base_url, **openai_client_kwargs ) + self._async_client = openai.AsyncOpenAI( + api_key=self._api_key, base_url=self._base_url, **openai_client_kwargs + ) + # ALoras that have been loaded for this model. self._aloras: dict[str, OpenAIAlora] = {} @@ -254,7 +271,6 @@ def generate_from_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ): """See `generate_from_chat_context`.""" @@ -266,7 +282,6 @@ def generate_from_context( ctx, format=format, model_options=model_options, - generate_logs=generate_logs, tool_calls=tool_calls, ) @@ -278,7 +293,6 @@ def generate_from_chat_context( format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: """Generates a new completion from the provided Context using this backend's `Formatter`.""" @@ -302,7 +316,6 @@ def generate_from_chat_context( ctx, format=format, model_options=model_options, - generate_logs=generate_logs, tool_calls=tool_calls, ) @@ -314,7 +327,6 @@ def _generate_from_chat_context_alora( format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, ) -> ModelOutputThunk: match action: case ALoraRequirement(): @@ -338,17 +350,25 @@ def _generate_from_chat_context_alora( assert type(user_message) is str assert type(assistant_message) is str assert format is None, "Structured outputs are not supported by ALoRAs." + + model_opts = self._simplify_and_merge(model_options, is_chat_context=True) + alora_output = alora_for_this_request.generate_using_strings( input=user_message, response=assistant_message, constraint=action.description, # type: ignore + stream=model_opts.get(ModelOption.STREAM, False), ) - return self.formatter.parse( - action, - ModelOutputThunk( - alora_output, meta={"alora_name": alora_for_this_request.name} - ), - ) + + # The alora function doesn't set up all the fields. + alora_output._context = linearized_ctx + alora_output._action = action + alora_output._model_options = model_options + + # TODO: Figure out what info we want to populate for aloras here. + alora_output._generate_log = GenerateLog() + + return alora_output @staticmethod def message_to_openai_message(msg: Message): @@ -392,10 +412,8 @@ def _generate_from_chat_context_standard( format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: - # NOTE: Currently, the `thinking` param is going to be set to "medium" if `thinking` is True, else it is None. 
model_opts = self._simplify_and_merge( model_options, is_chat_context=ctx.is_chat_context ) @@ -456,7 +474,9 @@ def _generate_from_chat_context_standard( formatted_tools = convert_tools_to_json(tools) use_tools = len(formatted_tools) > 0 - chat_response: ChatCompletion = self._client.chat.completions.create( + chat_response: Coroutine[ + Any, Any, ChatCompletion | openai.AsyncStream[ChatCompletionChunk] + ] = self._async_client.chat.completions.create( model=self._hf_model_id, messages=conversation, # type: ignore reasoning_effort=thinking, # type: ignore @@ -468,36 +488,133 @@ def _generate_from_chat_context_standard( ), ) # type: ignore - result = ModelOutputThunk( - value=chat_response.choices[0].message.content, - meta={ - "oai_chat_response": chat_response.choices[0].model_dump() - }, # NOTE: Using model dump here to comply with `TemplateFormatter` - tool_calls=self._extract_model_tool_requests(tools, chat_response), + output = ModelOutputThunk(None) + output._context = linearized_context + output._action = action + output._model_options = model_opts + + # Processing functions only pass the ModelOutputThunk (and current chunk of response). Bind the other vars necessary for + # each processing step. + output._process = self.processing + output._post_process = functools.partial( + self.post_processing, + tools=tools, + conversation=conversation, + thinking=thinking, + seed=model_opts.get(ModelOption.SEED, None), ) - parsed_result = self.formatter.parse(source_component=action, result=result) + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. - if generate_logs is not None: - assert isinstance(generate_logs, list) - generate_log = GenerateLog() - generate_log.prompt = conversation - generate_log.backend = f"openai::{self.model_id!s}" - generate_log.model_options = model_opts - generate_log.date = datetime.datetime.now() - generate_log.model_output = chat_response - generate_log.extra = { - "format": format, - "thinking": thinking, - "tools_available": tools, - "tools_called": result.tool_calls, - "seed": model_opts.get("seed", None), - } - generate_log.action = action - generate_log.result = parsed_result - generate_logs.append(generate_log) + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. + output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) + ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present + raise e + + return output + + async def processing( + self, mot: ModelOutputThunk, chunk: ChatCompletion | ChatCompletionChunk + ): + """Called during generation to add information from a single ChatCompletion or ChatCompletionChunk to the ModelOutputThunk. 
+ + For OpenAI, tool call parsing is handled in the post processing step.""" + if mot._thinking is None: + mot._thinking = "" + if mot._underlying_value is None: + mot._underlying_value = "" + + if isinstance(chunk, ChatCompletion): + message = chunk.choices[0].message + + if hasattr(message, "reasoning_content"): + thinking_chunk = message.reasoning_content # type: ignore + if thinking_chunk is not None: + mot._thinking += thinking_chunk + + content_chunk = message.content + if content_chunk is not None: + mot._underlying_value += content_chunk + + mot._meta["oai_chat_response"] = chunk.choices[0].model_dump() + + elif isinstance(chunk, ChatCompletionChunk): + message_delta = chunk.choices[0].delta + if hasattr(message_delta, "reasoning_content"): + thinking_chunk = message_delta.reasoning_content # type: ignore + if thinking_chunk is not None: + mot._thinking += thinking_chunk + + content_chunk = message_delta.content + if content_chunk is not None: + mot._underlying_value += content_chunk + + if mot._meta.get("oai_chat_response_streamed", None) is None: + mot._meta["oai_chat_response_streamed"] = [] + mot._meta["oai_chat_response_streamed"].append( + chunk.choices[0].model_dump() + ) + + async def post_processing( + self, + mot: ModelOutputThunk, + tools: dict[str, Callable], + conversation: list[dict], + thinking, + seed, + ): + """Called when generation is done.""" + # Reconstruct the chat_response from chunks if streamed. + streamed_chunks = mot._meta.get("oai_chat_response_streamed", None) + if streamed_chunks is not None: + mot._meta["oai_chat_response"] = chat_completion_delta_merge( + streamed_chunks + ) - return parsed_result + assert mot._action is not None, ( + "ModelOutputThunks should have their action assigned during generation" + ) + assert mot._model_options is not None, ( + "ModelOutputThunks should have their model_opts assigned during generation" + ) + + # OpenAI streamed responses give you chunks of tool calls. + # As a result, we have to store data between calls and only then + # check for complete tool calls in the post_processing step. + tool_chunk = extract_model_tool_requests(tools, mot._meta["oai_chat_response"]) + if tool_chunk is not None: + if mot.tool_calls is None: + mot.tool_calls = {} + # Merge the tool_chunk dict. + for key, val in tool_chunk.items(): + mot.tool_calls[key] = val + + self.formatter.parse(mot._action, mot) + + # Generate the log for this ModelOutputThunk. 
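+ # The log is attached to the thunk itself (mot._generate_log) rather than appended to a
+ # caller-supplied generate_logs list; callers such as the session collect it after awaiting the result.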
+ generate_log = GenerateLog() + generate_log.prompt = conversation + generate_log.backend = f"openai::{self.model_id!s}" + generate_log.model_options = mot._model_options + generate_log.date = datetime.datetime.now() + generate_log.model_output = mot._meta["oai_chat_response"] + generate_log.extra = { + "format": format, + "thinking": thinking, + "tools_available": tools, + "tools_called": mot.tool_calls, + "seed": seed, + } + generate_log.action = mot._action + generate_log.result = mot + mot._generate_log = generate_log def _generate_from_raw( self, @@ -571,31 +688,6 @@ def _generate_from_raw( return results - def _extract_model_tool_requests( - self, tools: dict[str, Callable], chat_response: ChatCompletion - ) -> dict[str, ModelToolCall] | None: - model_tool_calls: dict[str, ModelToolCall] = {} - calls = chat_response.choices[0].message.tool_calls - if calls: - for tool_call in calls: - tool_name = tool_call.function.name # type: ignore - tool_args = tool_call.function.arguments # type: ignore - - func = tools.get(tool_name) - if func is None: - FancyLogger.get_logger().warning( - f"model attempted to call a non-existing function: {tool_name}" - ) - continue # skip this function if we can't find it. - - # Returns the args as a string. Parse it here. - args = json.loads(tool_args) - model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args) - - if len(model_tool_calls) > 0: - return model_tool_calls - return None - def add_alora(self, alora: "OpenAIAlora"): """Loads an ALora for this backend. diff --git a/mellea/backends/types.py b/mellea/backends/types.py index 2a0b288d..d7f0db12 100644 --- a/mellea/backends/types.py +++ b/mellea/backends/types.py @@ -17,7 +17,7 @@ class ModelOption: """ TOOLS = "@@@tools@@@" - """Must be a list of callables or a dict[str, Callable].""" + """Must be a list[Callable] or a dict[str, Callable] where str is the name of the function.""" MAX_NEW_TOKENS = "@@@max_new_tokens@@@" SYSTEM_PROMPT = "@@@system_prompt@@@" @@ -25,6 +25,7 @@ class ModelOption: CONTEXT_WINDOW = "@@@context_window@@@" THINKING = "@@@thinking@@@" SEED = "@@@seed@@@" + STREAM = "@@@stream@@@" @staticmethod def replace_keys(options: dict, from_to: dict[str, str]) -> dict[str, Any]: diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index c7ed0b8a..e9400748 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -1,9 +1,12 @@ """A generic WatsonX.ai compatible backend that wraps around the watson_machine_learning library.""" +import asyncio import datetime +import functools import json import os -from collections.abc import Callable +import warnings +from collections.abc import AsyncGenerator, Callable, Coroutine from typing import Any from ibm_watsonx_ai import APIClient, Credentials @@ -19,12 +22,18 @@ convert_tools_to_json, ) from mellea.backends.types import ModelOption +from mellea.helpers.async_helpers import send_to_queue from mellea.helpers.fancy_logger import FancyLogger +from mellea.helpers.openai_compatible_helpers import ( + chat_completion_delta_merge, + extract_model_tool_requests, +) from mellea.stdlib.base import ( CBlock, Component, Context, GenerateLog, + GenerateType, ModelOutputThunk, ModelToolCall, ) @@ -56,6 +65,15 @@ def __init__( api_key : watsonx API key. Defaults to None. project_id : watsonx project ID. Defaults to None. """ + + # There are bugs with the Watsonx python sdk related to async event loops; + # using the same watsonx backend across multiple event loops causes errors. 
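+ # For that reason the backend is deprecated below, and several Watsonx async tests are marked xfail.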
+ warnings.warn( + "Watsonx Backend is deprecated, use 'LiteLLM' or 'OpenAI' Backends instead", + DeprecationWarning, + 2, + ) + super().__init__( model_id=model_id, formatter=( @@ -91,6 +109,7 @@ def __init__( "system": ModelOption.SYSTEM_PROMPT, "max_tokens": ModelOption.MAX_NEW_TOKENS, "tools": ModelOption.TOOLS, + "stream": ModelOption.STREAM, } # A mapping of Mellea specific ModelOptions to the specific names for this backend. # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`. @@ -105,6 +124,7 @@ def __init__( self.to_mellea_model_opts_map_completions = { "random_seed": ModelOption.SEED, "max_new_tokens": ModelOption.MAX_NEW_TOKENS, + "stream": ModelOption.STREAM, } # See notes above. self.from_mellea_model_opts_map_completions = { @@ -192,7 +212,6 @@ def generate_from_context( *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ): """See `generate_from_chat_context`.""" @@ -204,7 +223,6 @@ def generate_from_context( ctx, format=format, model_options=model_options, - generate_logs=generate_logs, tool_calls=tool_calls, ) @@ -216,7 +234,6 @@ def generate_from_chat_context( format: type[BaseModelSubclass] | None = None, # Type[BaseModelSubclass] is a class object of a subclass of BaseModel model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, tool_calls: bool = False, ) -> ModelOutputThunk: """Generates a new completion from the provided Context using this backend's `Formatter`.""" @@ -274,47 +291,159 @@ def generate_from_chat_context( FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") formatted_tools = convert_tools_to_json(tools) - chat_response = self._model.chat( - messages=conversation, - tools=formatted_tools, - tool_choice_option=( - "auto" if formatted_tools and len(formatted_tools) > 0 else "none" - ), - params=self._make_backend_specific_and_remove( - model_opts, is_chat_context=ctx.is_chat_context - ), - ) - # If a tool is called, there might not be content in the message. - response_message = chat_response["choices"][0]["message"].get("content", "") - result = ModelOutputThunk( - value=response_message, - meta={"oai_chat_response": chat_response["choices"][0]}, - tool_calls=self._extract_model_tool_requests(tools, chat_response), + chat_response: ( + Coroutine[Any, Any, AsyncGenerator] | Coroutine[Any, Any, dict] | None + ) = None + + stream = model_opts.get(ModelOption.STREAM, False) + if stream: + chat_response = self._model.achat_stream( + messages=conversation, + tools=formatted_tools, + tool_choice_option=( + "auto" if formatted_tools and len(formatted_tools) > 0 else "none" + ), + params=self._make_backend_specific_and_remove( + model_opts, is_chat_context=ctx.is_chat_context + ), + ) + else: + chat_response = self._model.achat( + messages=conversation, + tools=formatted_tools, + tool_choice_option=( + "auto" if formatted_tools and len(formatted_tools) > 0 else "none" + ), + params=self._make_backend_specific_and_remove( + model_opts, is_chat_context=ctx.is_chat_context + ), + ) + + output = ModelOutputThunk(None) + output._context = linearized_context + output._action = action + output._model_options = model_opts + + # Processing functions only pass the ModelOutputThunk (and current chunk of response). Bind the other vars necessary for + # each processing step. 
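+ # e.g. (illustrative): once the stream is drained, the bound partial is awaited roughly as
+ # self.post_processing(mot, conversation=conversation, tools=tools, seed=...).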
+ output._process = self.processing + output._post_process = functools.partial( + self.post_processing, + conversation=conversation, + tools=tools, + seed=model_opts.get(ModelOption.SEED, None), ) - parsed_result = self.formatter.parse(source_component=action, result=result) + try: + # To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine. + # We can also support synchronous calls by adding a flag and changing this ._generate function. - if generate_logs is not None: - assert isinstance(generate_logs, list) - generate_log = GenerateLog() - generate_log.prompt = conversation - generate_log.backend = f"watsonx::{self.model_id!s}" - generate_log.model_options = model_opts - generate_log.date = datetime.datetime.now() - generate_log.model_output = chat_response - generate_log.extra = { - "format": format, - # "thinking": thinking, - "tools_available": tools, - "tools_called": result.tool_calls, - "seed": model_opts.get(ModelOption.SEED, None), - } - generate_log.result = parsed_result - generate_log.action = action - generate_logs.append(generate_log) + # This function should always be called from a running event loop so we don't have to worry about + # scheduling the task to a specific event loop here. + output._generate = asyncio.create_task( + send_to_queue(chat_response, output._async_queue) + ) + output._generate_type = GenerateType.ASYNC + except RuntimeError as e: + # Most likely cause is running this function without an event loop present + raise e + + return output - return parsed_result + async def processing(self, mot: ModelOutputThunk, chunk: dict): + """Called during generation to add information from a single ChatCompletion or ChatCompletionChunk to the ModelOutputThunk. + + For OpenAI-like APIs, tool call parsing is handled in the post processing step.""" + if mot._thinking is None: + mot._thinking = "" + if mot._underlying_value is None: + mot._underlying_value = "" + + if len(chunk["choices"]) < 1: + return # Empty chunk. Note: this has some metadata information, but ignoring for now. + + # Watsonx returns dicts. Distinguish streaming and non-streaming based on their fields. + not_streaming = chunk["choices"][0].get("message", None) is not None + if not_streaming: + message: dict = chunk["choices"][0].get("message", dict()) + + thinking_chunk = message.get("reasoning_content", None) + if thinking_chunk is not None: + mot._thinking += thinking_chunk + + content_chunk = message.get("content", "") + if content_chunk is not None: + mot._underlying_value += content_chunk + + mot._meta["oai_chat_response"] = chunk["choices"][0] + + else: # Streaming. + message_delta: dict = chunk["choices"][0].get("delta", dict()) + + thinking_chunk = message_delta.get("reasoning_content", None) + if thinking_chunk is not None: + mot._thinking += thinking_chunk + + content_chunk = message_delta.get("content", None) + if content_chunk is not None: + mot._underlying_value += content_chunk + + if mot._meta.get("oai_chat_response_streamed", None) is None: + mot._meta["oai_chat_response_streamed"] = [] + mot._meta["oai_chat_response_streamed"].append(chunk["choices"][0]) + + async def post_processing( + self, + mot: ModelOutputThunk, + conversation: list[dict], + tools: dict[str, Callable], + seed, + ): + """Called when generation is done.""" + # Reconstruct the chat_response from chunks if streamed. 
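+ # The streamed delta dicts collected in processing() are folded back into a single choice dict by
+ # chat_completion_delta_merge, so the tool-call extraction and logging below treat streamed and
+ # non-streamed responses uniformly.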
+ streamed_chunks = mot._meta.get("oai_chat_response_streamed", None) + if streamed_chunks is not None: + mot._meta["oai_chat_response"] = chat_completion_delta_merge( + streamed_chunks + ) + + assert mot._action is not None, ( + "ModelOutputThunks should have their action assigned during generation" + ) + assert mot._model_options is not None, ( + "ModelOutputThunks should have their model_opts assigned during generation" + ) + + # OpenAI streamed responses give you chunks of tool calls. + # As a result, we have to store data between calls and only then + # check for complete tool calls in the post_processing step. + tool_chunk = extract_model_tool_requests(tools, mot._meta["oai_chat_response"]) + if tool_chunk is not None: + if mot.tool_calls is None: + mot.tool_calls = {} + # Merge the tool_chunk dict. + for key, val in tool_chunk.items(): + mot.tool_calls[key] = val + + self.formatter.parse(mot._action, mot) + + # Generate the log for this ModelOutputThunk. + generate_log = GenerateLog() + generate_log.prompt = conversation + generate_log.backend = f"watsonx::{self.model_id!s}" + generate_log.model_options = mot._model_options + generate_log.date = datetime.datetime.now() + generate_log.model_output = mot._meta["oai_chat_response"] + generate_log.extra = { + "format": format, + "tools_available": tools, + "tools_called": mot.tool_calls, + "seed": seed, + } + generate_log.result = mot + generate_log.action = mot._action + mot._generate_log = generate_log def _generate_from_raw( self, diff --git a/mellea/helpers/async_helpers.py b/mellea/helpers/async_helpers.py new file mode 100644 index 00000000..d3ed2744 --- /dev/null +++ b/mellea/helpers/async_helpers.py @@ -0,0 +1,31 @@ +import asyncio +from collections.abc import AsyncIterator, Coroutine +from typing import Any + + +async def send_to_queue( + co: Coroutine[Any, Any, AsyncIterator | Any] | AsyncIterator, aqueue: asyncio.Queue +) -> None: + """Processes the output of an async chat request by sending the output to an async queue.""" + try: + if isinstance(co, Coroutine): + aresponse = await co + else: + # Some backends (hf) don't actually return their iterator from an + # async function. As a result, there's no coroutine to wait for here. + aresponse = co + + if isinstance(aresponse, AsyncIterator): + async for item in aresponse: + await aqueue.put(item) + else: + await aqueue.put(aresponse) + + # Always add a sentinel value to indicate end of stream. + await aqueue.put(None) + + # Typically, nothing awaits this function directly (only through the queue). + # As a result, we have to be careful about catching all errors and propagating + # them to the queue. 
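+ # Consumers (e.g. ModelOutputThunk.astream) re-raise any Exception they pull off the queue.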
+ except Exception as e: + await aqueue.put(e) diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py new file mode 100644 index 00000000..87daf0bc --- /dev/null +++ b/mellea/helpers/openai_compatible_helpers.py @@ -0,0 +1,118 @@ +"""A file for helper functions that deal with OpenAI API compatible helpers.""" + +import json +from collections.abc import Callable +from typing import Any + +from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import ModelToolCall + + +def extract_model_tool_requests( + tools: dict[str, Callable], response: dict[str, Any] +) -> dict[str, ModelToolCall] | None: + """Extracts tool calls from the dict representation of an OpenAI-like chat response object.""" + model_tool_calls: dict[str, ModelToolCall] = {} + calls = response["message"].get("tool_calls", None) + if calls: + for tool_call in calls: + tool_name = tool_call["function"]["name"] # type: ignore + tool_args = tool_call["function"]["arguments"] # type: ignore + + func = tools.get(tool_name) + if func is None: + FancyLogger.get_logger().warning( + f"model attempted to call a non-existing function: {tool_name}" + ) + continue # skip this function if we can't find it. + + args = {} + if tool_args is not None: + # Returns the args as a string. Parse it here. + args = json.loads(tool_args) + model_tool_calls[tool_name] = ModelToolCall(tool_name, func, args) + + if len(model_tool_calls) > 0: + return model_tool_calls + return None + + +def chat_completion_delta_merge( + chunks: list[dict], force_all_tool_calls_separate: bool = False +) -> dict: + """Takes a list of deltas from `ChatCompletionChunk`s and merges them into a single dict representing the `ChatCompletion` choice. + + Args: + chunks: the list of dicts that represent the message deltas + force_all_tool_calls_separate: if `True`, tool calls in separate message deltas will not be merged (even if their index values are the same); use when providers do not return the correct index value for tool calls. If using this option, all tool calls must be fully populated in a single delta since they won't be merged. + """ + + merged: dict[str, Any] = dict() + + # `delta`s map to a single choice. + merged["finish_reason"] = None + merged["index"] = 0 # We always do the first choice. + merged["logprobs"] = None + merged["stop_reason"] = None + + # message fields + message: dict[str, Any] = dict() + message["content"] = "" + message["reasoning_content"] = "" + message["role"] = None + m_tool_calls: list[dict] = [] + message["tool_calls"] = m_tool_calls + merged["message"] = message + + for chunk in chunks: + # Handle top level fields. + if chunk.get("finish_reason", None) is not None: + merged["finish_reason"] = chunk["finish_reason"] + if chunk.get("stop_reason", None) is not None: + merged["stop_reason"] = chunk["stop_reason"] + + # Handle fields of the message object. + if message["role"] is None and chunk["delta"].get("role", None) is not None: + message["role"] = chunk["delta"]["role"] + + if chunk["delta"].get("content", None) is not None: + message["content"] += chunk["delta"]["content"] + + thinking = chunk["delta"].get("reasoning_content", None) + if thinking is not None: + message["reasoning_content"] += thinking + + tool_calls = chunk["delta"].get("tool_calls", None) + if tool_calls is not None: + # Merge the pieces of each tool call from separate chunks into one dict. 
+ # Example: + # chunks: [{'arguments': None, 'name': 'get_weather_precise'}, {'arguments': '{"location": "', 'name': None}, {'arguments': 'Dallas"}', 'name': None}] + # -> [{'arguments': '{"location": "Dallas"}', 'name': 'get_weather_precise'}] + for tool_call in tool_calls: + idx: int = tool_call["index"] + current_tool = None + + # In a few special cases, we want to force all tool calls to be separate regardless of the index value. + # If not forced, check that the tool call index in the response isn't already in our list. + create_new_tool_call = force_all_tool_calls_separate or ( + idx > len(m_tool_calls) - 1 + ) + if create_new_tool_call: + current_tool = {"function": {"name": "", "arguments": None}} + m_tool_calls.append(current_tool) + else: + # This tool has already started to be defined. + current_tool = m_tool_calls[idx] + + # Get the info from the function chunk. + fx_info = tool_call["function"] + if fx_info["name"] is not None: + current_tool["function"]["name"] += fx_info["name"] + + if fx_info["arguments"] is not None: + # Only populate args if there are any to add. + if current_tool["function"]["arguments"] is None: + current_tool["function"]["arguments"] = "" + current_tool["function"]["arguments"] += fx_info["arguments"] + + return merged diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index 960a3203..d7ad307c 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -3,11 +3,13 @@ from __future__ import annotations import abc +import asyncio import base64 import binascii import datetime -from collections.abc import Callable, Iterable, Mapping -from copy import deepcopy +import enum +from collections.abc import Callable, Coroutine, Iterable, Mapping +from copy import copy, deepcopy from dataclasses import dataclass from io import BytesIO from typing import Any, Protocol, runtime_checkable @@ -131,7 +133,7 @@ def format_for_llm(self) -> TemplateRepresentation | str: def get_images_from_component(c: Component) -> None | list[ImageBlock]: """Gets images from a `Component` if they are present and a non-empty list, otherwise returns None.""" if hasattr(c, "images"): - imgs = c.images + imgs = c.images # type: ignore if imgs is not None: assert isinstance(imgs, list), "images field must be a list." assert all(isinstance(im, ImageBlock) for im in imgs), ( @@ -147,6 +149,14 @@ def get_images_from_component(c: Component) -> None | list[ImageBlock]: return None +class GenerateType(enum.Enum): + """Used to track what functions can be used to extract a value from a ModelOutputThunk.""" + + NONE = None + ASYNC = 1 + SYNC = 2 + + class ModelOutputThunk(CBlock): """A `ModelOutputThunk` is a special type of `CBlock` that we know came from a model's output. It is possible to instantiate one without the output being computed yet.""" @@ -160,11 +170,156 @@ def __init__( """Initializes as a cblock, optionally also with a parsed representation from an output formatter.""" super().__init__(value, meta) self.parsed_repr: CBlock | Component | Any | None = parsed_repr + + # Set computed to True if a value is passed in. + self._computed: bool = True if value is not None else False + + # Additional fields that should be standardized across apis. self.tool_calls = tool_calls + self._thinking: str | None = None + + # Used for tracking generation. + self._context: list[Component | CBlock] | None = None + self._action: Component | CBlock | None = None + self._model_options: dict[str, Any] | None = None + + # Used for async and async streaming.
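+ # The backend's generate task pushes response chunks onto _async_queue (see send_to_queue);
+ # astream()/avalue() drain it. A None sentinel marks end-of-stream and an Exception item
+ # propagates backend errors. Illustrative usage sketch (mirrors the tests added below):
+ #     mot = backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options={ModelOption.STREAM: True})
+ #     partial = await mot.astream()  # value so far
+ #     full = await mot.avalue()      # complete value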
+ self._async_queue: asyncio.Queue = asyncio.Queue(maxsize=20) + self._chunk_size = 3 # Minimum number of chunks to stream at a single time. + + # _generate and _generate_type are linked. _generate will determine + # what gets set for _generate_type. _generate_type determines what + # function(s) can be used to get the value of the ModelOutputThunk. + self._generate: asyncio.Task[None] | None = None + self._generate_type: GenerateType = GenerateType.NONE + self._generate_extra: asyncio.Task[Any] | None = ( + None # Currently only used by hf. + ) + self._process: Callable[[ModelOutputThunk, Any], Coroutine] | None = None + self._post_process: Callable[[ModelOutputThunk], Coroutine] | None = None + + self._generate_log: GenerateLog | None = None def is_computed(self): """Returns true only if this Thunk has already been filled.""" - return self.value is not None + return self._computed + + @property + def value(self) -> str | None: + """Gets the value of the block.""" + if not self._computed: + return None + return self._underlying_value + + @value.setter + def value(self, v: str): + """Sets the value of the block.""" + self._underlying_value = v + + async def avalue(self) -> str: + """Returns the value of the ModelOutputThunk. Can be used for both async streaming and async non-streaming. + + Raises: + Exception: Propagates any errors from the underlying inference engine api request. + RuntimeError: If called when the ModelOutputThunk's generate function is not async compatible. + """ + if self._computed: + assert self.value # If computed, the value cannot be None. + return self.value + + if not self._generate_type == GenerateType.ASYNC: + raise RuntimeError( + f"Cannot use `ModelOutputThunk.avalue()` when the generate function is using `{self._generate_type.name}`" + ) + + while not self._computed: + await self.astream() + + assert self.value is not None # If computed, the value cannot be None. + return self.value + + # If we require a function that returns only the new chunks of data, we can implement that similarly. + async def astream(self) -> str: + """Returns the ModelOutputThunk's partial value including the next chunk(s). Can be used for both async streaming and async non-streaming. + + Returns the value of the ModelOutputThunk if streaming is done. + + **Note**: Be careful with calling this function. Only call it from one location at a time. This means you shouldn't pass a ModelOutputThunk to + multiple coroutines/tasks and call astream from those coroutines/tasks simultaneously. We have considered solutions to this but are waiting until + we see this error happen in a real use case. + + Raises: + Exception: Propagates any errors from the underlying inference engine api request. + RuntimeError: If called when the ModelOutputThunk's generate function is not async compatible. + """ + if self._computed: + assert self.value is not None # If computed, the value cannot be None. + return self.value + + if not self._generate_type == GenerateType.ASYNC: + raise RuntimeError( + f"Cannot use `ModelOutputThunk.astream()` when the generate function is using `{self._generate_type.name}`" + ) + + # Type of the chunk depends on the backend. + chunks: list[Any | None] = [] + while True: + try: + item = self._async_queue.get_nowait() + chunks.append(item) + except asyncio.QueueEmpty: + # We've exhausted the current items in the queue. + break + + # Make sure we always get the minimum chunk size. 
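+ # Await additional items until at least _chunk_size chunks have been gathered, stopping early at
+ # the None sentinel or an Exception.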
+ while len(chunks) <= self._chunk_size: + if len(chunks) > 0: + if chunks[-1] is None or isinstance(chunks[-1], Exception): + break # Hit sentinel value or an error. + # We could switch to relying on the `done` / `finish_reason` field of chunks, + # but that forces us to know about the chunk type here. Prefer sentinel values + # for now. + + item = await self._async_queue.get() + chunks.append(item) + + # Process the sentinel value if it's there. + if chunks[-1] is None: + chunks.pop() # Remove the sentinel value. + self._computed = True + + # Shouldn't be needed, but cancel the Tasks this ModelOutputThunk relied on. + if self._generate is not None: + self._generate.cancel() + if self._generate_extra is not None: + # Covers an hf edge case. The task is done generating anything useful but isn't `done` yet. + await self._generate_extra + self._generate_extra.cancel() + + # If ModelOutputThunks get too bulky, we can do additional cleanup here + # and set fields to None. + + elif isinstance(chunks[-1], Exception): + # For now, just re-raise the exception. + # It's possible that we hit this error after already streaming some + # chunks. We should investigate allowing recovery in the future. + raise chunks[-1] + + for chunk in chunks: + assert self._process is not None + await self._process(self, chunk) + + if self._computed: + assert self._post_process is not None + await self._post_process(self) + + return self._underlying_value # type: ignore + + def __repr__(self): + """Provides a python-parsable representation (usually). + + Differs from CBlock because `._meta` can be very large for ModelOutputThunks.""" + return f"ModelOutputThunk({self.value})" def blockify(s: str | CBlock | Component) -> CBlock | Component: @@ -236,7 +391,9 @@ def insert_turn( @abc.abstractmethod def copy(self) -> Context: - """Produces a deep copy of the current Context's contents, allowing for branch-and-merge style semantics over a Context.""" + """Produces a copy of the current Context's contents, allowing for branch-and-merge style semantics over a Context. + + Implementations should not copy the actual objects in the context but retain a reference to them.""" ... @abc.abstractmethod @@ -397,6 +554,16 @@ def __str__(self): [f" {c!s}" for c in self._ctx] ) + def copy(self): + """Copies all attributes of the Context. `_ctx` and `_log_ctx` are shallow copies. + + This means that the lists are different (you can independently insert to the new/old context), but that the objects in the old/new lists are the same at copy time. + """ + new = copy(self) + new._ctx = copy(self._ctx) + new._log_ctx = copy(self._log_ctx) + return new + class LinearContext(BasicContext): """Initializes a linear context with unbounded window_size and is_chat=True by default.""" @@ -465,10 +632,6 @@ def _hash_for_kv_cache(self): """Constructs a hash that corresponds to the string contents of the KV cache associated with this context.""" assert False, "not supported yet." - def copy(self): - """Constructs a deep copy of this Context.""" - return deepcopy(self) - class SimpleContext(BasicContext): """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved. @@ -520,10 +683,6 @@ def _hash_for_kv_cache(self): """Constructs a hash that corresponds to the string contents of the KV cache associated with this context.""" assert False, "not supported yet." 
- def copy(self): - """Constructs a deep copy of this Context.""" - return deepcopy(self) - @dataclass class TemplateRepresentation: diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index 495092cb..0b0b0061 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -41,7 +41,12 @@ class ValidationResult: """ValidationResults store the output of a Requirement's validation. They can be used to return additional info from validation functions, which is useful for sampling/repairing.""" def __init__( - self, result: bool, *, reason: str | None = None, score: float | None = None + self, + result: bool, + *, + reason: str | None = None, + score: float | None = None, + thunk: ModelOutputThunk | None = None, ): """The result of a requirement's validation. @@ -51,10 +56,12 @@ def __init__( result: a boolean that is true if the requirement passed reason: a reason for the result score: if your validator gives you a score back, you can add this as metadata + thunk: if your validator utilizes a backend to generate a response, the ModelOutputThunk returned from that request """ self._result = result self._reason = reason self._score = score + self._thunk = thunk @property def reason(self) -> str | None: @@ -64,6 +71,10 @@ def reason(self) -> str | None: def score(self) -> float | None: return self._score + @property + def thunk(self) -> ModelOutputThunk | None: + return self._thunk + def as_bool(self) -> bool: """""" return self._result @@ -101,14 +112,13 @@ def __init__( # Used for validation. Do not manually populate. self._output: str | None = None - def validate( + async def validate( self, backend: Backend, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, ) -> ValidationResult: """Chooses the appropriate validation strategy and applies that strategy.""" if self.validation_fn is not None: @@ -127,15 +137,14 @@ def validate( req_copy = copy(self) req_copy._output = last_output.value llm_as_a_judge_result = backend.generate_from_context( - req_copy, - ctx, - format=format, - model_options=model_options, - generate_logs=generate_logs, + req_copy, ctx, format=format, model_options=model_options ) + await llm_as_a_judge_result.avalue() + return ValidationResult( result=self.output_to_bool(llm_as_a_judge_result), reason=llm_as_a_judge_result.value, + thunk=llm_as_a_judge_result, ) def parts(self): @@ -210,14 +219,13 @@ def __init__( raise NotImplementedError self.preference_ordering: str = preference_ordering.lower() - def validate( + async def validate( self, backend: Backend, ctx: Context, *, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, ) -> ValidationResult: """Chooses the appropriate validation strategy and applies that strategy. 
Asserts that the returned ValidationResult has a valid score.""" if self.validation_fn is not None: @@ -241,18 +249,16 @@ def validate( req_copy = copy(self) req_copy._output = last_output.value llm_as_a_judge_result = backend.generate_from_context( - req_copy, - ctx, - format=format, - model_options=model_options, - generate_logs=generate_logs, + req_copy, ctx, format=format, model_options=model_options ) + await llm_as_a_judge_result.avalue() result = self.output_to_bool(llm_as_a_judge_result) return ValidationResult( result=result, reason=llm_as_a_judge_result.value, score=1 if result else 0, + thunk=llm_as_a_judge_result, ) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 845e3730..ff7ab3a2 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -1,7 +1,7 @@ """sampling methods go here.""" import abc -from collections.abc import Callable +from collections.abc import Callable, Coroutine from copy import deepcopy from typing import Any @@ -60,22 +60,22 @@ class SamplingStrategy(abc.ABC): # the function signature here matches that of m.validate validate: ( - Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None - ) = None - - generate: ( - Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + Callable[ + [list[Requirement], Context, Any, Any], + Coroutine[Any, Any, list[ValidationResult]], + ] | None ) = None + generate: Callable[[Component, Context], ModelOutputThunk] | None = None + @abc.abstractmethod - def sample( + async def sample( self, action: Component, context: Context, requirements: list[Requirement], *, - generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method is the abstract method for sampling a given instruction. @@ -86,7 +86,6 @@ def sample( action : The action object to be sampled. context: The context to be passed to the sampling strategy. requirements: The requirements to be used by the sampling strategy (merged with global requirements). - generate_logs: Optional list of GenerateLog objects. If None, no collection happens. validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. """ @@ -100,12 +99,12 @@ def __init__( self, *, loop_budget: int = 1, - validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] + validate: Callable[ + [list[Requirement], Context, Any, Any], + Coroutine[Any, Any, list[ValidationResult]], + ] | None = None, - generate: ( - Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] - | None - ) = None, + generate: (Callable[[Component, Context], ModelOutputThunk] | None) = None, requirements: list[Requirement] | None = None, ): """Initialize a new instance of the class with default parameters. @@ -167,14 +166,13 @@ def select_from_failure( """ ... - def sample( + async def sample( self, action: Component, context: Context, requirements: list[Requirement], *, show_progress: bool = True, - generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -183,7 +181,6 @@ def sample( action : The action object to be sampled. context: The context to be passed to the sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. - generate_logs: If provided, the generations will be logged. 
requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. @@ -234,10 +231,17 @@ def sample( flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result = self.generate(new_action, ctx, generate_logs) + result = self.generate(new_action, ctx) + await result.avalue() # validation pass - val_scores = self.validate(reqs, validation_ctx, result) + val_scores_co = self.validate( + reqs, + validation_ctx, + result, + input=None, # type: ignore + ) + val_scores = await val_scores_co # match up reqs with scores constraint_scores = list(zip(reqs, val_scores)) @@ -250,6 +254,11 @@ def sample( # if all vals are true -- break and return success if all(bool(s[1]) for s in constraint_scores): flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. + result._generate_log.is_final_result = True + return SamplingResult( result, success=True, @@ -278,6 +287,12 @@ def sample( assert best_failed_index < len(sampled_results), ( "The select_from_failure method did not return a valid result. It has to selected from failed_results." ) + + assert ( + sampled_results[best_failed_index]._generate_log is not None + ) # Cannot be None after generation. + sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore + return SamplingResult( sampled_results[best_failed_index], success=False, @@ -388,14 +403,13 @@ class BestofNSamplingStrategy(BaseSamplingStrategy): Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer """ - def sample( + async def sample( self, action: Component, context: Context, requirements: list[Requirement], *, show_progress: bool = True, - generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -404,7 +418,6 @@ def sample( action : The action object to be sampled. context: The context to be passed to the sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. - generate_logs: If provided, the generations will be logged. requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. @@ -471,16 +484,18 @@ def sample( flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result = self.generate(new_action, ctx, generate_logs) + result = self.generate(new_action, ctx) + await result.avalue() # validation pass # action has user turn - val_scores = self.validate( + val_scores_co = self.validate( reqs, validation_ctx, result, input=action._description, # type: ignore ) + val_scores = await val_scores_co # match up reqs with scores constraint_scores = list(zip(reqs, val_scores)) @@ -494,6 +509,11 @@ def sample( # if all vals are true, save it and continue to get next sample if all(bool(s[1]) for s in constraint_scores): flog.info("SUCCESS") + assert ( + result._generate_log is not None + ) # Cannot be None after generation. 
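+ # Successful samples have their logs flagged via is_final_result so they can be told apart from
+ # failed attempts among the collected generate logs.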
+ result._generate_log.is_final_result = True + successful_sampled_results.append(result) successful_sampled_scores.append(constraint_scores) successful_sampled_actions.append(new_action) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index dee6a067..47bf0581 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -2,7 +2,10 @@ from __future__ import annotations +import asyncio import contextvars +import threading +from collections.abc import Coroutine from copy import deepcopy from typing import Any, Literal, overload @@ -183,6 +186,10 @@ def __init__(self, backend: Backend, ctx: Context | None = None): self._session_logger = FancyLogger.get_logger() self._context_token = None + # Necessary for async. `m.*` functions should always run in this event loop. + self._event_loop = asyncio.new_event_loop() + threading.Thread(target=self._event_loop.run_forever, daemon=True).start() + def __enter__(self): """Enter context manager and set this session as the current global session.""" self._context_token = _context_session.set(self) @@ -195,6 +202,9 @@ def __exit__(self, exc_type, exc_val, exc_tb): _context_session.reset(self._context_token) self._context_token = None + def __del__(self): + self._close_event_loop() + def _push_model_state(self, new_backend: Backend, new_model_opts: dict): """The backend and model options used within a `Context` can be temporarily changed. This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`. @@ -226,11 +236,36 @@ def reset(self): def cleanup(self) -> None: """Clean up session resources.""" + self._close_event_loop() self.reset() self._backend_stack.clear() if hasattr(self.backend, "close"): self.backend.close() # type: ignore + def _close_event_loop(self) -> None: + """Called when deleting the session. Cleans up the session's event loop.""" + if self._event_loop: + try: + tasks = asyncio.all_tasks(self._event_loop) + for task in tasks: + task.cancel() + + async def finalize_tasks(): + # TODO: We can log errors here if needed. + await asyncio.gather(*tasks, return_exceptions=True) + + out = asyncio.run_coroutine_threadsafe( + finalize_tasks(), self._event_loop + ) + + # Timeout if needed. + out.result(5) + except Exception: + pass + + # Finally stop the event loop for this session. + self._event_loop.stop() + def summarize(self) -> ModelOutputThunk: """Summarizes the current context.""" raise NotImplementedError() @@ -297,6 +332,48 @@ def act( ) -> ModelOutputThunk | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. + Args: + action: the Component from which to generate. + requirements: used as additional requirements when a sampling strategy is provided + strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: if set, the BaseModel to use for constrained decoding. + model_options: additional model options, which will upsert into the model/backend's defaults. + tool_calls: if true, tool calling is enabled. + + Returns: + A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + """ + + # Run everything in the specific event loop for this session. 
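+ # The session starts a dedicated background event loop on a daemon thread in __init__;
+ # run_coroutine_threadsafe schedules _act there and out.result() blocks until it completes,
+ # keeping act() synchronous for callers.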
+ out = asyncio.run_coroutine_threadsafe( + self._act( + action, + requirements=requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ), + self._event_loop, + ) + + return out.result() + + async def _act( + self, + action: Component, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = None, + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + ) -> ModelOutputThunk | SamplingResult: + """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. + Args: action: the Component from which to generate. requirements: used as additional requirements when a sampling strategy is provided @@ -323,32 +400,32 @@ def act( ctx=self.ctx, format=format, model_options=model_options, - generate_logs=generate_logs, tool_calls=tool_calls, ) - assert len(generate_logs) == 1, "Simple call can only add one generate_log" - generate_logs[-1].is_final_result = True + await result.avalue() + + # ._generate_log should never be None after generation. + assert result._generate_log is not None + result._generate_log.is_final_result = True + generate_logs.append(result._generate_log) else: # Default validation strategy just validates all of the provided requirements. if strategy.validate is None: strategy.validate = ( - lambda reqs, val_ctx, output, input=None: self.validate( # type: ignore + lambda reqs, val_ctx, output, input=None: self._validate( # type: ignore reqs, output=output, input=input ) - ) # type: ignore + ) # Default generation strategy just generates from context. if strategy.generate is None: strategy.generate = ( - lambda sample_action, - gen_ctx, - g_logs: self.backend.generate_from_context( + lambda sample_action, gen_ctx: self.backend.generate_from_context( sample_action, ctx=gen_ctx, format=format, model_options=model_options, - generate_logs=g_logs, tool_calls=tool_calls, ) ) @@ -356,25 +433,22 @@ def act( if requirements is None: requirements = [] - sampling_result = strategy.sample( - action, self.ctx, requirements=requirements, generate_logs=generate_logs + sampling_result = await strategy.sample( + action, self.ctx, requirements=requirements ) - # make sure that one Log is marked as the one related to sampling_result.result - if sampling_result.success: - # if successful, the last log is the one related - generate_logs[-1].is_final_result = True - else: - # Find the log where log.result and sampling_result.result match - selected_log = [ - log for log in generate_logs if log.result == sampling_result.result - ] - assert len(selected_log) == 1, ( - "There should only be exactly one log corresponding to the single result. " - ) - selected_log[0].is_final_result = True + assert sampling_result.sample_generations is not None + for result in sampling_result.sample_generations: + assert ( + result._generate_log is not None + ) # Cannot be None after generation. 
+ generate_logs.append(result._generate_log) result = sampling_result.result + assert sampling_result.result._generate_log is not None + assert sampling_result.result._generate_log.is_final_result, ( + "generate logs from the final result returned by the sampling strategy must be marked as final" + ) self.ctx.insert_turn(ContextTurn(action, result), generate_logs=generate_logs) @@ -530,6 +604,33 @@ def validate( input: CBlock | None = None, ) -> list[ValidationResult]: """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" + # Run everything in the specific event loop for this session. + out = asyncio.run_coroutine_threadsafe( + self._validate( + reqs=reqs, + output=output, + format=format, + model_options=model_options, + generate_logs=generate_logs, + input=input, + ), + self._event_loop, + ) + + # Wait for and return the result. + return out.result() + + async def _validate( + self, + reqs: Requirement | list[Requirement], + *, + output: CBlock | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] | None = None, + input: CBlock | None = None, + ) -> list[ValidationResult]: + """Asynchronous version of .validate; validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" # Turn a solitary requirement in to a list of requirements, and then reqify if needed. reqs = [reqs] if not isinstance(reqs, list) else reqs reqs = [Requirement(req) if type(req) is str else req for req in reqs] @@ -537,6 +638,7 @@ def validate( validation_target_ctx = self.ctx else: validation_target_ctx = SimpleContext() + if input is not None: # some validators may need input as well as output validation_target_ctx.insert_turn( @@ -548,17 +650,38 @@ def validate( ) else: validation_target_ctx.insert(output) - rvs = [] + + rvs: list[ValidationResult] = [] + coroutines: list[Coroutine[Any, Any, ValidationResult]] = [] + for requirement in reqs: - val_result = requirement.validate( + val_result_co = requirement.validate( self.backend, validation_target_ctx, format=format, model_options=model_options, - generate_logs=generate_logs, ) + coroutines.append(val_result_co) + + for val_result in await asyncio.gather(*coroutines): rvs.append(val_result) + # If the validator utilized a backend to generate a result, attach the corresponding + # info to the generate_logs list. + if generate_logs is not None: + if val_result.thunk is not None: + thunk = val_result.thunk + assert ( + thunk._generate_log is not None + ) # Cannot be None after generation. + generate_logs.append(thunk._generate_log) + else: + # We have to append None here so that the logs line-up. + # TODO: A better solution should be found for this edge case. + # This is the only scenario where ValidationResults are supposed to line + # up with GenerateLogs. 
+ generate_logs.append(None) # type: ignore + return rvs def query( diff --git a/pyproject.toml b/pyproject.toml index 32d0afc2..8263f2fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -146,6 +146,7 @@ ignore = [ # "UP007", # Option and Union # "UP035", # `typing.Set` is deprecated, use `set` instead" "PD901", # Avoid using the generic variable name `df` for DataFrames + "C901", # Complexity warnings ] [tool.ruff.lint.pydocstyle] @@ -159,7 +160,7 @@ combine-as-imports = true split-on-trailing-comma = false [tool.codespell] -ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd' +ignore-words-list = 'mellea,hashi,noo,Asai,asai,nd,mot' check-filenames = true check-hidden = false regex = "(? bool: # should yield to true - but, of course, is model dependent assert h is True +@pytest.mark.qualitative +def test_async_parallel_requests(session): + async def parallel_requests(): + model_opts = {ModelOption.STREAM: True} + mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. + assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value + asyncio.run(parallel_requests()) + +@pytest.mark.qualitative +def test_async_avalue(session): + async def avalue(): + mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value + asyncio.run(avalue()) if __name__ == "__main__": import pytest diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index b90d93fb..8e9b8631 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -1,3 +1,4 @@ +import asyncio import json import pydantic @@ -126,6 +127,41 @@ class Answer(pydantic.BaseModel): f"formatting directive failed for {random_result.value}: {e.json()}" ) +def test_async_parallel_requests(session): + async def parallel_requests(): + model_opts = {ModelOption.STREAM: True} + mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value + asyncio.run(parallel_requests()) + +def test_async_avalue(session): + async def avalue(): + mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value + asyncio.run(avalue()) if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index def41004..d773e645 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -1,4 +1,5 @@ # test/rits_backend_tests/test_openai_integration.py +import asyncio import os import pydantic @@ -10,7 +11,7 @@ from mellea.backends.model_ids import META_LLAMA_3_2_1B from mellea.backends.openai import OpenAIBackend from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk +from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk, SimpleContext @pytest.fixture(scope="module") @@ -18,11 +19,11 @@ def backend(gh_run: int): """Shared OpenAI backend configured for Ollama.""" if gh_run == 1: return OpenAIBackend( - model_id=META_LLAMA_3_2_1B, - formatter=TemplateFormatter(model_id=META_LLAMA_3_2_1B), - base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1", - api_key="ollama", - ) + model_id=META_LLAMA_3_2_1B.ollama_name, # type: ignore + formatter=TemplateFormatter(model_id=META_LLAMA_3_2_1B.hf_model_name), # type: ignore + base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1", + api_key="ollama", + ) else: return OpenAIBackend( model_id="granite3.3:8b", @@ -39,12 +40,14 @@ def m_session(backend): yield session session.reset() + @pytest.mark.qualitative def test_instruct(m_session): result = m_session.instruct("Compute 1+1.") assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore + @pytest.mark.qualitative def test_multiturn(m_session): m_session.instruct("What is the capital of France?") @@ -66,6 +69,15 @@ def test_multiturn(m_session): # assert "granite3.3:8b" in result.value # self.m.reset() + +@pytest.mark.qualitative +def test_chat(m_session): + output_message = m_session.chat("What is 1+1?") + assert "2" in output_message.content, ( + f"Expected a message with content containing 2 but found {output_message}" + ) + + @pytest.mark.qualitative def test_format(m_session): class Person(pydantic.BaseModel): @@ -132,6 +144,56 @@ class Email(pydantic.BaseModel): # assert False, f"formatting directive failed for {random_result.value}: {e.json()}" +def test_async_parallel_requests(m_session): + async def parallel_requests(): + model_opts = {ModelOption.STREAM: True} + mot1 = m_session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext(), model_options=model_opts + ) + mot2 = m_session.backend.generate_from_context( + CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts + ) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await 
mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. + assert m1_final_val.startswith(m1_val), ( + "final val should contain the first streamed chunk" + ) + assert m2_final_val.startswith(m2_val), ( + "final val should contain the first streamed chunk" + ) + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value + + asyncio.run(parallel_requests()) + + +def test_async_avalue(m_session): + async def avalue(): + mot1 = m_session.backend.generate_from_context( + CBlock("Say Hello."), SimpleContext() + ) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value + + asyncio.run(avalue()) + + if __name__ == "__main__": import pytest diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index 85dedd66..0a9b917a 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -1,4 +1,5 @@ # test/rits_backend_tests/test_watsonx_integration.py +import asyncio import os import pydantic @@ -9,7 +10,7 @@ from mellea.backends.formatter import TemplateFormatter from mellea.backends.types import ModelOption from mellea.backends.watsonx import WatsonxAIBackend -from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk +from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk, SimpleContext @pytest.fixture(scope="module") @@ -41,14 +42,21 @@ def test_instruct(session: MelleaSession): assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore - +@pytest.mark.xfail(reason="watsonx python sdk has weird interactions with event loops; causes some errors with pytest.") @pytest.mark.qualitative def test_multiturn(session: MelleaSession): session.instruct("What is the capital of France?") answer = session.instruct("Tell me the answer to the previous question.") assert "Paris" in answer.value # type: ignore +@pytest.mark.qualitative +def test_chat(session): + output_message = session.chat("What is 1+1?") + assert "2" in output_message.content, ( + f"Expected a message with content containing 2 but found {output_message}" + ) +@pytest.mark.xfail(reason="watsonx python sdk has weird interactions with event loops; causes some errors with pytest.") @pytest.mark.qualitative def test_format(session: MelleaSession): class Person(pydantic.BaseModel): @@ -93,6 +101,46 @@ def test_generate_from_raw(session: MelleaSession): assert len(results) == len(prompts) +@pytest.mark.qualitative +def test_async_parallel_requests(session): + async def parallel_requests(): + model_opts = {ModelOption.STREAM: True} + mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + + m1_val = None + m2_val = None + if not mot1.is_computed(): + m1_val = await mot1.astream() + if not mot2.is_computed(): + m2_val = await mot2.astream() + + assert m1_val is not None, "should be a string val after generation" + assert m2_val is not None, "should be a string val after generation" + + m1_final_val = await mot1.avalue() + m2_final_val = await mot2.avalue() + + # Ideally, we would be able to assert that m1_final_val != m1_val, but sometimes the first streaming response + # contains the full response. 
+ assert m1_final_val.startswith(m1_val), "final val should contain the first streamed chunk" + assert m2_final_val.startswith(m2_val), "final val should contain the first streamed chunk" + + assert m1_final_val == mot1.value + assert m2_final_val == mot2.value + asyncio.run(parallel_requests()) + +# TODO: If this becomes a big issue, we will just have to re-instantiate the ModelInference object between requests. +# Ideally, we would only do this when creating a new m.session from the same backend. +@pytest.mark.xfail(reason="watsonx python sdk apparently doesn't support running across multiple async event loops.") +@pytest.mark.qualitative +def test_async_avalue(session): + async def avalue(): + mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + m1_final_val = await mot1.avalue() + assert m1_final_val is not None + assert m1_final_val == mot1.value + asyncio.run(avalue()) if __name__ == "__main__": import pytest diff --git a/test/stdlib_basics/test_requirement.py b/test/stdlib_basics/test_requirement.py index 12af105a..5d11b00a 100644 --- a/test/stdlib_basics/test_requirement.py +++ b/test/stdlib_basics/test_requirement.py @@ -1,9 +1,10 @@ +import asyncio import pytest -from mellea.stdlib.base import ModelOutputThunk +from mellea.stdlib.base import LinearContext, ModelOutputThunk from mellea.stdlib.requirement import Requirement, simple_validate -from mellea.stdlib.session import SimpleContext, start_session +from mellea.stdlib.session import start_session -ctx = SimpleContext() +ctx = LinearContext() ctx.insert(ModelOutputThunk("test")) def test_llmaj_validation_req_output_field(): @@ -11,7 +12,10 @@ def test_llmaj_validation_req_output_field(): req = Requirement("Must output test.") assert req._output is None - _ = req.validate(m.backend,ctx=ctx) + async def val(): + _ = await req.validate(m.backend,ctx=ctx) + asyncio.run(val()) + assert req._output is None, "requirement's output shouldn't be updated during/after validation" def test_simple_validate_bool(): diff --git a/test/stdlib_basics/test_sampling_ctx.py b/test/stdlib_basics/test_sampling_ctx.py index f178aae4..2e0ba033 100644 --- a/test/stdlib_basics/test_sampling_ctx.py +++ b/test/stdlib_basics/test_sampling_ctx.py @@ -1,3 +1,4 @@ +import pytest from mellea import LinearContext, start_session from mellea.backends import ModelOption from mellea.stdlib.sampling import ( @@ -61,3 +62,6 @@ def test_ctx_for_multiturn(self): self._run_asserts_for_ctx_testing(res) assert len(self.m.last_prompt()) == len(res.sample_generations)*2-1, "For n sampling iterations there should be 2n-1 prompt conversation elements in the last prompt." + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/test_tool_calls.py b/test/test_tool_calls.py index 4b966040..f2643cc6 100644 --- a/test/test_tool_calls.py +++ b/test/test_tool_calls.py @@ -32,27 +32,27 @@ def table() -> Table: def test_tool_called_from_context_action(m: MelleaSession, table: Table): """Make sure tools can be called from actions in the context.""" - r = 10 m.ctx.reset() # Insert a component with tools into the context. m.ctx.insert(table) - returned_tool = False - for i in range(r): - # Make sure the specific generate call is on a different action with - # no tools to make sure it's a tool from the context. - result = m.backend.generate_from_context( - CBlock("Add a row to the table."), - m.ctx, - tool_calls=True - ) - if result.tool_calls is not None and len(result.tool_calls) > 0: - returned_tool = True - break + # Create fake tools. 
+ def test1(): ... + def test2(): ... - assert returned_tool, f"did not return a tool after {r} attempts" + model_opts = { + ModelOption.TOOLS: [test1, test2] + } + + tools = {} + + add_tools_from_model_options(tools, model_opts) + assert "test1" in tools + assert "test2" in tools + add_tools_from_context_actions(tools, m.ctx.actions_for_available_tools()) + assert "to_markdown" in tools def test_tool_called(m: MelleaSession, table: Table): """We don't force tools to be called. As a result, this test might unexpectedly fail."""
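Not part of the diff itself: a minimal usage sketch of the async `ModelOutputThunk` API that the new `test_async_parallel_requests` / `test_async_avalue` tests above exercise (`astream()`, `avalue()`, `is_computed()`, `ModelOption.STREAM`). The calls to `generate_from_context`, `CBlock`, and `SimpleContext` mirror the test code; `start_session()` defaulting to a locally reachable backend (e.g. the Ollama setup used in CI) is an assumption, not something this diff establishes.

```python
# Illustrative sketch only -- mirrors the async streaming tests added in this PR.
import asyncio

from mellea import start_session
from mellea.backends.types import ModelOption
from mellea.stdlib.base import CBlock, SimpleContext


async def stream_two_requests(session):
    # generate_from_context must run inside an event loop; each call returns a
    # ModelOutputThunk immediately and schedules the generation as a task.
    opts = {ModelOption.STREAM: True}
    mot1 = session.backend.generate_from_context(
        CBlock("Say Hello."), SimpleContext(), model_options=opts
    )
    mot2 = session.backend.generate_from_context(
        CBlock("Say Goodbye!"), SimpleContext(), model_options=opts
    )

    # astream() returns the text streamed so far (possibly already complete);
    # avalue() awaits the final value.
    first_chunk = await mot1.astream() if not mot1.is_computed() else mot1.value
    full_1 = await mot1.avalue()
    full_2 = await mot2.avalue()
    assert full_1.startswith(first_chunk)
    return full_1, full_2


if __name__ == "__main__":
    m = start_session()  # assumed default backend; any configured MelleaSession works
    print(asyncio.run(stream_two_requests(m)))
```

As in the tests, the first streamed chunk is only guaranteed to be a prefix of the final value; short completions may arrive in a single chunk, so equality checks between the two are deliberately avoided.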