2 changes: 1 addition & 1 deletion mellea/backends/litellm.py
@@ -74,7 +74,7 @@ def __init__(
# users should only be specifying a single one in their request.
self.to_mellea_model_opts_map = {
"system": ModelOption.SYSTEM_PROMPT,
"reasoning_effort": ModelOption.THINKING, # TODO: JAL; see which of these are actually extracted...
"reasoning_effort": ModelOption.THINKING,
"seed": ModelOption.SEED,
"max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
"max_tokens": ModelOption.MAX_NEW_TOKENS,
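The `to_mellea_model_opts_map` above translates LiteLLM-style option names into Mellea's `ModelOption` keys. A minimal sketch of that remapping idea, using a hypothetical `remap_options` helper (the real backend performs this translation inside its own merge logic, which is not shown in this diff):

```python
from mellea.backends.types import ModelOption

# Hypothetical helper, for illustration only: translate backend-specific option
# names into Mellea's ModelOption keys, leaving unknown keys untouched.
def remap_options(user_opts: dict, opts_map: dict) -> dict:
    return {opts_map.get(key, key): value for key, value in user_opts.items()}

opts_map = {
    "reasoning_effort": ModelOption.THINKING,
    "max_tokens": ModelOption.MAX_NEW_TOKENS,
}
print(remap_options({"reasoning_effort": "high"}, opts_map))
# -> {ModelOption.THINKING: "high"}
```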
110 changes: 80 additions & 30 deletions mellea/backends/ollama.py
@@ -17,12 +17,14 @@
add_tools_from_model_options,
)
from mellea.backends.types import ModelOption
from mellea.helpers.async_helpers import send_to_queue
from mellea.helpers.fancy_logger import FancyLogger
from mellea.stdlib.base import (
CBlock,
Component,
Context,
GenerateLog,
GenerateType,
ModelOutputThunk,
ModelToolCall,
TemplateRepresentation,
@@ -234,6 +236,7 @@ def generate_from_context(
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
stream: bool = False,
):
"""See `generate_from_chat_context`."""
assert ctx.is_chat_context, (
@@ -246,6 +249,7 @@ def generate_from_context(
model_options=model_options,
generate_logs=generate_logs,
tool_calls=tool_calls,
stream=stream,
)

def generate_from_chat_context(
@@ -257,11 +261,15 @@ def generate_from_chat_context(
model_options: dict | None = None,
generate_logs: list[GenerateLog] | None = None,
tool_calls: bool = False,
stream: bool = False,
) -> ModelOutputThunk:
"""Generates a new completion from the provided Context using this backend's `Formatter`.

This implementation treats the `Context` as a chat history, and uses the `ollama.Client.chat()` interface to generate a completion.
This will not always work, because sometimes we want to use non-chat models.

Raises:
RuntimeError: If not called from a thread with a running event loop.
"""
model_opts = self._simplify_and_merge(model_options)

@@ -311,45 +319,87 @@ def generate_from_chat_context(
FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}")

# Generate a chat response from ollama, using the chat messages.
chat_response: ollama.ChatResponse = self._client.chat(
chat_response = self._async_client.chat(
model=self._get_ollama_model_id(),
messages=conversation,
tools=list(tools.values()),
think=model_opts.get(ModelOption.THINKING, None),
options=self._make_backend_specific_and_remove(model_opts),
stream=False,
stream=stream,
format=format.model_json_schema() if format is not None else None,
) # type: ignore

result = ModelOutputThunk(
value=chat_response.message.content, # For an ollama tool call, content will be an empty string.
meta={"chat_response": chat_response},
tool_calls=self._extract_model_tool_requests(tools, chat_response),
)
output = ModelOutputThunk(None)
output._context = linearized_context
output._action = action

def processing(mot: ModelOutputThunk, chunk: ollama.ChatResponse):
"""Called during generation to add information from a single ChatResponse to the ModelOutputThunk."""
if mot._thinking is None:
mot._thinking = ""
thinking_chunk = chunk.message.thinking
if thinking_chunk is not None:
mot._thinking += thinking_chunk

if mot._underlying_value is None:
mot._underlying_value = ""
content_chunk = chunk.message.content
if content_chunk is not None:
mot._underlying_value += content_chunk

if mot.tool_calls is None:
mot.tool_calls = {}
tool_chunk = self._extract_model_tool_requests(tools, chunk)
if tool_chunk is not None:
# Merge the tool_chunk dict.
for key, val in tool_chunk.items():
mot.tool_calls[key] = val

output._process = processing

formatted_result = self.formatter.parse(action, result)

if generate_logs is not None:
# noinspection DuplicatedCode
assert isinstance(generate_logs, list)
generate_log = GenerateLog()
generate_log.prompt = conversation
generate_log.backend = f"ollama::{self.model_id!s}"
generate_log.model_options = model_opts
generate_log.date = datetime.datetime.now()
generate_log.model_output = chat_response
generate_log.extra = {
"format": format,
"thinking": model_opts.get(ModelOption.THINKING, None),
"tools_available": tools,
"tools_called": result.tool_calls,
"seed": model_opts.get(ModelOption.SEED, None),
}
generate_log.action = action
generate_log.result = formatted_result
generate_logs.append(generate_log)

return formatted_result
def post_processing(mot: ModelOutputThunk):
"""Called when generation is done."""
self.formatter.parse(action, mot)

output._post_process = post_processing

try:
# To support lazy computation, will need to remove this create_task and store just the unexecuted coroutine.
# We can also support synchronous calls by adding a flag and changing this ._generate function.

# This function should always be called from a running event loop so we don't have to worry about
# scheduling the task to a specific event loop here.
output._generate = asyncio.create_task(
send_to_queue(chat_response, output._async_queue)
)
output._generate_type = GenerateType.ASYNC
except RuntimeError as e:
# Most likely cause is running this function without an event loop present
raise e

return output

# if generate_logs is not None:
# # noinspection DuplicatedCode
# assert isinstance(generate_logs, list)
# generate_log = GenerateLog()
# generate_log.prompt = conversation
# generate_log.backend = f"ollama::{self.model_id!s}"
# generate_log.model_options = model_opts
# generate_log.date = datetime.datetime.now()
# generate_log.model_output = chat_response
# generate_log.extra = {
# "format": format,
# "thinking": model_opts.get(ModelOption.THINKING, None),
# "tools_available": tools,
# "tools_called": result.tool_calls,
# "seed": model_opts.get(ModelOption.SEED, None),
# }
# generate_log.action = action
# generate_log.result = formatted_result
# generate_logs.append(generate_log)

# return formatted_result

def _generate_from_raw(
self,
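With the changes above, `generate_from_chat_context` now returns an unresolved `ModelOutputThunk` whose generation runs as an asyncio task. A minimal caller-side sketch, assuming `backend` is an instance of this Ollama backend and `action`/`ctx` are a valid `Component` and chat `Context` (their construction is outside this diff, and the positional arguments are assumed to match the call made in `generate_from_context`):

```python
import asyncio

async def run():
    # Must be invoked from a running event loop, because the backend schedules
    # an asyncio.Task that feeds the thunk's internal queue.
    thunk = backend.generate_from_chat_context(action, ctx, stream=True)

    # astream() pulls the next chunks off the queue and returns the text
    # accumulated so far, so this prints progressively longer partial output.
    while not thunk.is_computed():
        print(await thunk.astream())

    # ...or simply await the final value (also works with stream=False).
    print(await thunk.avalue())

asyncio.run(run())
```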
21 changes: 21 additions & 0 deletions mellea/helpers/async_helpers.py
@@ -0,0 +1,21 @@
import asyncio
from collections.abc import AsyncIterator, Coroutine
from typing import Any


async def send_to_queue(
co: Coroutine[Any, Any, AsyncIterator] | Coroutine[Any, Any, Any],
aqueue: asyncio.Queue,
) -> None:
"""Processes the output of an async chat request by sending the output to an async queue."""
aresponse = await co

if isinstance(aresponse, AsyncIterator):
async for item in aresponse:
await aqueue.put(item)

else:
await aqueue.put(aresponse)

# Always add a sentinel value to indicate end of stream.
await aqueue.put(None)
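A self-contained sketch of the contract `send_to_queue` implements: it awaits the coroutine, drains an `AsyncIterator` result item by item (or forwards a single non-streaming result), and always terminates the queue with a `None` sentinel. The `fake_stream` coroutine below is illustrative only, not part of mellea:

```python
import asyncio
from collections.abc import AsyncIterator

from mellea.helpers.async_helpers import send_to_queue

async def fake_stream() -> AsyncIterator[str]:
    # Stand-in for an async client call that resolves to an async iterator.
    async def chunks():
        for piece in ("Hel", "lo", "!"):
            yield piece
    return chunks()

async def main():
    queue: asyncio.Queue = asyncio.Queue()
    await send_to_queue(fake_stream(), queue)

    parts = []
    while (item := await queue.get()) is not None:  # None marks end of stream.
        parts.append(item)
    print("".join(parts))  # -> Hello!

asyncio.run(main())
```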
119 changes: 117 additions & 2 deletions mellea/stdlib/base.py
@@ -3,9 +3,11 @@
from __future__ import annotations

import abc
import asyncio
import base64
import binascii
import datetime
import enum
from collections.abc import Callable, Iterable, Mapping
from copy import deepcopy
from dataclasses import dataclass
@@ -131,7 +133,7 @@ def format_for_llm(self) -> TemplateRepresentation | str:
def get_images_from_component(c: Component) -> None | list[ImageBlock]:
"""Gets images from a `Component` if they are present and a non-empty list, otherwise returns None."""
if hasattr(c, "images"):
imgs = c.images
imgs = c.images # type: ignore
if imgs is not None:
assert isinstance(imgs, list), "images field must be a list."
assert all(isinstance(im, ImageBlock) for im in imgs), (
@@ -147,6 +149,14 @@ def get_images_from_component(c: Component) -> None | list[ImageBlock]:
return None


class GenerateType(enum.Enum):
"""Used to track what functions can be used to extract a value from a ModelOutputThunk."""

NONE = None
ASYNC = 1
SYNC = 2


class ModelOutputThunk(CBlock):
"""A `ModelOutputThunk` is a special type of `CBlock` that we know came from a model's output. It is possible to instantiate one without the output being computed yet."""

@@ -160,11 +170,116 @@ def __init__(
"""Initializes as a cblock, optionally also with a parsed representation from an output formatter."""
super().__init__(value, meta)
self.parsed_repr: CBlock | Component | Any | None = parsed_repr

# Set computed to True if a value is passed in.
self._computed: bool = True if value is not None else False

# Additional fields that should be standardized across apis.
self.tool_calls = tool_calls
self._thinking: str | None = None

# Used for tracking generation.
self._context: list[Component | CBlock] | None = None
self._action: Component | CBlock | None = None
self._model_options: dict[str, Any] | None = None

# Used for async and async streaming.
self._async_queue: asyncio.Queue = asyncio.Queue(maxsize=20)
self._chunk_size = 3 # Minimum number of chunks to stream at a single time.

# _generate and _generate_type are linked. _generate will determine
# what gets set for _generate_type. _generate_type determines what
# function(s) can be used to get the value of the ModelOutputThunk.
self._generate: asyncio.Task[None] | None = None
self._generate_type: GenerateType = GenerateType.NONE
self._process: Callable[[ModelOutputThunk, Any], None] | None = None
self._post_process: Callable[[ModelOutputThunk], None] | None = None

def is_computed(self):
"""Returns true only if this Thunk has already been filled."""
return self.value is not None
return self._computed

@property
def value(self) -> str | None:
"""Gets the value of the block."""
if not self._computed:
return None
return self._underlying_value

@value.setter
def value(self, v: str):
"""Sets the value of the block."""
self._underlying_value = v

async def avalue(self) -> str:
"""Returns the value of the ModelOutputThunk. Can be used for both async streaming and async non-streaming.

Raises:
RuntimeError: If called when the ModelOutputThunk's generate function is not async compatible.
"""
if self._computed:
assert self.value # If computed, the value cannot be None.
return self.value

if not self._generate_type == GenerateType.ASYNC:
raise RuntimeError(
f"Cannot use `ModelOutputThunk.avalue()` when the generate function is using `{self._generate_type.name}`"
)

while not self._computed:
await self.astream()

assert self.value is not None # If computed, the value cannot be None.
return self.value

async def astream(self) -> str | None:
"""Returns the next chunk of data. Can be used for both async streaming and async non-streaming.

Returns `None` if the ModelOutputThunk is already computed or no values are left.

Raises:
RuntimeError: If called when the ModelOutputThunk's generate function is not async compatible.
"""
if self._computed:
# Return an empty item since there's nothing more to stream.
return None

if not self._generate_type == GenerateType.ASYNC:
raise RuntimeError(
f"Cannot use `ModelOutputThunk.astream()` when the generate function is using `{self._generate_type.name}`"
)

# Type of the chunk depends on the backend.
chunks: list[Any | None] = []
while True:
try:
item = self._async_queue.get_nowait()
chunks.append(item)
except asyncio.QueueEmpty:
# We've exhausted the current items in the queue.
break

# Make sure we always get the minimum chunk size.
while len(chunks) <= self._chunk_size:
if len(chunks) > 0 and chunks[-1] is None:
break # Hit sentinel value.

item = await self._async_queue.get()
chunks.append(item)

# Process the sentinel value if it's there.
if chunks[-1] is None:
chunks.pop() # Remove the sentinel value.
self._computed = True
for chunk in chunks:
assert self._process is not None
self._process(self, chunk)

if self._computed:
assert self._post_process is not None
self._post_process(self)

return self._underlying_value


def blockify(s: str | CBlock | Component) -> CBlock | Component:
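The new fields and methods above define a small contract between a backend and a `ModelOutputThunk`: the backend supplies a `_process` callback for each chunk, a `_post_process` callback for completion, a `_generate` task that feeds `_async_queue`, and marks the thunk `GenerateType.ASYNC`. A minimal sketch of that wiring, mirroring the ollama.py changes; the `fake_chat` coroutine and its string chunks are illustrative assumptions, not a real backend:

```python
import asyncio

from mellea.helpers.async_helpers import send_to_queue
from mellea.stdlib.base import GenerateType, ModelOutputThunk

async def fake_chat():
    # Stand-in for an async chat call that resolves to an async iterator of chunks.
    async def chunks():
        for piece in ("one ", "two ", "three"):
            yield piece
    return chunks()

async def main():
    thunk = ModelOutputThunk(None)

    # Accumulate each chunk into the thunk's underlying value.
    def processing(mot: ModelOutputThunk, chunk: str) -> None:
        if mot._underlying_value is None:
            mot._underlying_value = ""
        mot._underlying_value += chunk

    # Nothing extra to do once the sentinel is seen in this toy example.
    def post_processing(mot: ModelOutputThunk) -> None:
        pass

    thunk._process = processing
    thunk._post_process = post_processing
    thunk._generate = asyncio.create_task(
        send_to_queue(fake_chat(), thunk._async_queue)
    )
    thunk._generate_type = GenerateType.ASYNC

    print(await thunk.avalue())  # -> one two three

asyncio.run(main())
```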
3 changes: 2 additions & 1 deletion mellea/stdlib/requirement.py
@@ -101,7 +101,7 @@ def __init__(
# Used for validation. Do not manually populate.
self._output: str | None = None

def validate(
async def validate(
self,
backend: Backend,
ctx: Context,
@@ -133,6 +133,7 @@ def validate(
model_options=model_options,
generate_logs=generate_logs,
)
await llm_as_a_judge_result.avalue()
return ValidationResult(
result=self.output_to_bool(llm_as_a_judge_result),
reason=llm_as_a_judge_result.value,
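Because `validate` now awaits the LLM-as-a-judge result, it is a coroutine and callers must await it from an event loop. A minimal sketch, assuming `req`, `backend`, and `ctx` are an existing `Requirement`, `Backend`, and chat `Context`, and assuming `ValidationResult` exposes the `result`/`reason` values passed to its constructor:

```python
import asyncio

async def check():
    # validate() now awaits the judge's ModelOutputThunk internally via avalue().
    validation = await req.validate(backend, ctx)
    print(validation.result, validation.reason)

asyncio.run(check())
```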