2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -11,7 +11,7 @@ repos:
- id: ruff
name: "Ruff linter"
args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml]
files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$'
files: '^(mellea|tests).*\.(py|ipynb)$'

- repo: local
hooks:
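For reference, a minimal sketch of how this `files` pattern filters paths. Both regexes come from the hunk above; the sample paths are purely illustrative.

```python
import re

# Both patterns appear in the hunk above; the sample paths are illustrative only.
broad = r"^(mellea|tests|cli|docs).*\.(py|ipynb)$"
narrow = r"^(mellea|tests).*\.(py|ipynb)$"

samples = [
    "mellea/backends/ollama.py",       # matches both patterns
    "docs/examples/best_of_n/prm.py",  # matches only the broader pattern
    "cli/m.py",                        # matches only the broader pattern
    "README.md",                       # matches neither (extension not listed)
]

for path in samples:
    print(path, bool(re.match(broad, path)), bool(re.match(narrow, path)))
```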
3 changes: 1 addition & 2 deletions cli/alora/upload.py
@@ -5,8 +5,7 @@


def upload_model(weight_path: str, model_name: str, private: bool = True):
"""
Upload a trained adapter (LoRA/aLoRA) to Hugging Face Hub.
"""Upload a trained adapter (LoRA/aLoRA) to Hugging Face Hub.

Args:
weight_path (str): Directory containing adapter weights (from save_pretrained).
2 changes: 1 addition & 1 deletion cli/m.py
@@ -12,7 +12,7 @@
# Add a default callback for handling the default cli description.
@cli.callback()
def callback() -> None:
"""Perform M Tasks"""
"""Perform M Tasks."""


# Typer assumes that all commands are in the same file/module.
2 changes: 1 addition & 1 deletion docs/examples/best_of_n/prm.py
@@ -1,4 +1,4 @@
"""Example of Using Best of N with PRMs"""
"""Example of Using Best of N with PRMs."""

from docs.examples.helper import w
from mellea import start_session
3 changes: 1 addition & 2 deletions docs/examples/generative_slots/generative_slots.py
@@ -9,8 +9,7 @@ def classify_sentiment(text: str) -> Literal["positive", "negative"]: ...

@generative
def generate_summary(text: str) -> str:
"""
This is a function that takes in a string and generates a summary for the string.
"""This is a function that takes in a string and generates a summary for the string.
Keep your summary succinct and under 20 words.
"""

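To show how a generative slot like this is typically consumed, a minimal usage sketch; the call convention (session passed as the first argument) and the sample text are assumptions based on mellea's generative-slot examples, not part of this diff.

```python
# Hypothetical usage sketch: assumes a generative slot is invoked with a session
# as its first argument, as in mellea's generative-slot examples.
from mellea import start_session

m = start_session()  # assumed to default to a local Ollama backend
summary = generate_summary(
    m,
    text="Mellea is a library for writing generative programs...",  # illustrative input
)
print(summary)
```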
4 changes: 1 addition & 3 deletions docs/examples/information_extraction/101_with_gen_slots.py
@@ -8,9 +8,7 @@

@generative
def extract_all_person_names(doc: str) -> list[str]:
"""
Given a document, extract all person names. Return these names as list of strings.
"""
"""Given a document, extract all person names. Return these names as list of strings."""


# ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html
14 changes: 6 additions & 8 deletions docs/examples/mini_researcher/researcher.py
@@ -19,20 +19,20 @@

@cache
def get_session():
"""get M session (change model here)"""
"""Get M session (change model here)."""
return MelleaSession(backend=OllamaModelBackend(model_ids.IBM_GRANITE_3_3_8B))


@cache
def get_guardian_session():
"""get M session for the guardian model"""
"""Get M session for the guardian model."""
return MelleaSession(
backend=OllamaModelBackend(model_ids.IBM_GRANITE_GUARDIAN_3_0_2B)
)


def is_a_true_subset_of_b(a: list[str], b: list[str]) -> bool:
"""check if a is true subset of b."""
"""Check if a is true subset of b."""
all_in = True
for e in a:
if e not in b:
@@ -42,7 +42,7 @@ def is_a_true_subset_of_b(a: list[str], b: list[str]) -> bool:


def create_check_word_count(max_words: int) -> Callable[[str], bool]:
"""generate a maximum-word-count validation function."""
"""Generate a maximum-word-count validation function."""

def cc(s: str):
return len(s.split()) <= max_words
@@ -56,7 +56,7 @@ def cc(s: str):


def step_is_input_safe(guardian_session: MelleaSession, docs: list[str]) -> bool:
"""check if the list of docs has no harm."""
"""Check if the list of docs has no harm."""
is_safe = True
for i_doc, doc in enumerate(docs):
print(f"\nChecking Doc {i_doc + 1}/{len(docs)}", end="...")
@@ -73,7 +73,7 @@ def step_is_input_safe(guardian_session: MelleaSession, docs: list[str]) -> bool
def step_summarize_docs(
s: MelleaSession, docs: list[str], user_args: dict
) -> list[str]:
"""generate a task-specific document summary for each doc."""
"""Generate a task-specific document summary for each doc."""
summaries = []
for i_doc, doc in enumerate(docs): # type: ignore
print(f"\nSummarizing doc {i_doc + 1}/{len(docs)}", end="...")
Expand All @@ -91,7 +91,6 @@ def step_generate_outline(
s: MelleaSession, user_args: dict, context: list[RAGDocument]
) -> list[str]:
"""Generate a report outline using constraint decoding (formatted output)."""

print("\n Generating outline", end="...")

class SectionTitles(BaseModel):
@@ -165,7 +164,6 @@ def step_write_full_report(
outline: list[str],
) -> str:
"""Merge summaries and outline into a single report."""

print("\nWriting full report", end="...")

## Define Requirements
2 changes: 1 addition & 1 deletion docs/examples/notebooks/mcp_example.ipynb
@@ -126,7 +126,7 @@
"\n",
"@mcp.resource(\"greeting://{name}\")\n",
"def get_greeting(name: str) -> str:\n",
" \"\"\"Get a personalized greeting\"\"\"\n",
" \"\"\"Get a personalized greeting.\"\"\"\n",
" return f\"Hello, {name}!\""
]
}
2 changes: 1 addition & 1 deletion docs/examples/safety.py/guardian.py
@@ -1,4 +1,4 @@
"""Example of using the Guardian Requirement"""
"""Example of using the Guardian Requirement."""

from mellea import MelleaSession
from mellea.backends import model_ids
4 changes: 2 additions & 2 deletions docs/examples/tutorial/mcp_example.py
@@ -1,4 +1,4 @@
"""Example of an MCP server
"""Example of an MCP server.

You need to install the mcp package:
uv pip install "mcp[cli]"
@@ -50,5 +50,5 @@ def write_a_poem(word_limit: int) -> str:

@mcp.resource("greeting://{name}")
def get_greeting(name: str) -> str:
"""Get a personalized greeting"""
"""Get a personalized greeting."""
return f"Hello, {name}!"
10 changes: 7 additions & 3 deletions mellea/backends/aloras/huggingface/granite_aloras.py
@@ -33,8 +33,12 @@ def __init__(
"""Initialize after checking that the backend is correct.

Args:
name: name of the alora.
path_or_model_id: huggingface path or model id.
generation_prompt: the prompt required to activate the aLoRa.
backend: a LocalHFBackend that this alora is attached to.
constraint_prompt: a template that the constraint can be interpolated into; can only have a single `{}` slot.
include_constraint_in_alora_offset: whether to include the constraint prompt in the alora offset
include_constraint_in_alora_offset: whether to include the constraint prompt in the alora offset.
"""
super().__init__(name, path_or_model_id, generation_prompt, backend)

@@ -142,7 +146,6 @@ def _generate_using_cache(
self, cache_hit: HFAloraCacheInfo, constraint: str, force_yn: bool
) -> dict:
"""Returns the input object used for generation."""

# Must tokenize the constraint here since the requirement isn't known at initialization.
constraint_tokens = self._backend._tokenizer(
self._constraint_prompt.format(constraint), return_tensors="pt"
@@ -177,7 +180,6 @@ def _generate_not_using_cache(
self, input: str, response: str, constraint: str, force_yn: bool
) -> dict:
"""Returns the input object used for generation."""

# Params aren't needed when just getting the backend args.
backend_model_opts = self._backend._simplify_and_merge(None)
sys_prompt = backend_model_opts.get(ModelOption.SYSTEM_PROMPT, None)
@@ -226,6 +228,7 @@ async def processing(
force_yn: bool,
gen_prompt: str,
):
"""Called to process the incoming chunks."""
if mot._underlying_value is None:
mot._underlying_value = ""

@@ -248,6 +251,7 @@


async def post_processing(mot: ModelOutputThunk, backend: LocalHFBackend):
"""Called after all data has been received."""
backend.formatter.parse(mot._action, mot) # type: ignore


2 changes: 2 additions & 0 deletions mellea/backends/aloras/openai/granite_aloras.py
@@ -105,12 +105,14 @@ def generate_using_strings(


async def processing(mot: ModelOutputThunk, chunk: Completion):
"""Called to process the incoming chunks."""
if mot._underlying_value is None:
mot._underlying_value = ""
mot._underlying_value += chunk.choices[0].text


async def post_processing(backend: OpenAIBackend, mot: ModelOutputThunk):
"""Called after all data has been received."""
backend.formatter.parse(mot._action, mot) # type: ignore


3 changes: 2 additions & 1 deletion mellea/backends/formatter.py
@@ -37,7 +37,8 @@ def parse(
) -> ModelOutputThunk:
"""Parses the output from a model and sets the parsed_repr of the result ModelOutputThunk.

Returns the ModelOutputThunk that was passed in."""
Returns the ModelOutputThunk that was passed in.
"""
...

def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]:
7 changes: 4 additions & 3 deletions mellea/backends/huggingface.py
@@ -765,10 +765,12 @@ def __init__(


class HFProcessRewardModel(PRM, abc.ABC):
"""A Process Reward Model that works with a huggingface backend."""

def __init__(
self, model_name_or_path: str, score_token: str, device: str | None = None
):
"""Initialize an PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models
"""Initialize an PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models.

Args:
model_name_or_path (str): A local path to PRM or a huggingface PRM
@@ -803,13 +805,12 @@ def __init__(
)[0]

def stepify(self, content: str, step_separator: str) -> list[str]:
"""Splits the assistant response into steps to score
"""Splits the assistant response into steps to score.

Args:
content: assistant response to score
step_separator: string on which to separate the content into steps
"""

# convert assistant message into a list of steps
list_of_steps = [
step.strip() for step in content.split(step_separator) if step.strip != ""
3 changes: 2 additions & 1 deletion mellea/backends/litellm.py
@@ -321,7 +321,8 @@ async def processing(
):
"""Called during generation to add information from a single ModelResponse or a chunk / ModelResponseStream to the ModelOutputThunk.

For LiteLLM, tool call parsing is handled in the post processing step."""
For LiteLLM, tool call parsing is handled in the post processing step.
"""
if mot._thinking is None:
mot._thinking = ""
if mot._underlying_value is None:
6 changes: 6 additions & 0 deletions mellea/backends/ollama.py
@@ -539,6 +539,12 @@ async def post_processing(


def chat_response_delta_merge(mot: ModelOutputThunk, delta: ollama.ChatResponse):
"""Merges the individual ChatResponse chunks from a streaming response into a single ChatResponse.

Args:
mot: the ModelOutputThunk that the deltas are being used to populated.
delta: the most recent ollama ChatResponse.
"""
if mot._meta.get("chat_response", None) is None:
mot._meta["chat_response"] = delta
return # Return early, no need to merge.
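The new docstring above describes folding streamed `ChatResponse` deltas into a single response. As a rough, generic illustration of that accumulation pattern, a sketch that only concatenates message content; mellea's actual merge handles more fields, and this is not taken from its implementation.

```python
import ollama


def accumulate_stream_text(chunks: list[ollama.ChatResponse]) -> str:
    """Concatenate the streamed message content of each chunk (illustrative only)."""
    text = ""
    for delta in chunks:
        # Each streamed ChatResponse carries a partial assistant message.
        if delta.message and delta.message.content:
            text += delta.message.content
    return text
```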
7 changes: 6 additions & 1 deletion mellea/backends/openai.py
@@ -93,6 +93,7 @@ def __init__(
model_options : Generation options to pass to the LLM. Defaults to None.
default_to_constraint_checking_alora: If set to False then aloras will be deactivated. This is primarily for performance benchmarking and debugging.
api_key : API key for generation. Defaults to None.
kwargs : additional kwargs to pass when creating the OpenAI client.
"""
super().__init__(
model_id=model_id,
@@ -220,6 +221,7 @@ def _simplify_and_merge(

Args:
model_options: the model_options for this call
is_chat_context: set to True if using chat completion api

Returns:
a new dict
@@ -245,6 +247,7 @@ def _make_backend_specific_and_remove(

Args:
model_options: the model_options for this call
is_chat_context: set to True if using chat completion api

Returns:
a new dict
@@ -372,6 +375,7 @@ def _generate_from_chat_context_alora(

@staticmethod
def message_to_openai_message(msg: Message):
"""Serializes a mellea Message object to the message format required by OpenAI compatible api providers."""
if msg.images is not None:
img_list = [
{
@@ -526,7 +530,8 @@ async def processing(
):
"""Called during generation to add information from a single ChatCompletion or ChatCompletionChunk to the ModelOutputThunk.

For OpenAI, tool call parsing is handled in the post processing step."""
For OpenAI, tool call parsing is handled in the post processing step.
"""
if mot._thinking is None:
mot._thinking = ""
if mot._underlying_value is None:
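For reference, the general shape of an OpenAI-style chat message carrying both text and an image, which is the format a serializer like `message_to_openai_message` targets; the concrete values below are illustrative, not taken from mellea.

```python
# Illustrative only: an OpenAI-compatible chat message with mixed text/image content parts.
message = {
    "role": "user",
    "content": [
        {"type": "text", "text": "Describe this image."},
        {
            "type": "image_url",
            "image_url": {"url": "data:image/png;base64,<BASE64-ENCODED-IMAGE>"},
        },
    ],
}
```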
9 changes: 6 additions & 3 deletions mellea/backends/process_reward_models/__init__.py
@@ -1,21 +1,24 @@
"""Abstract interfaces for Backends that implement Process Reward Models (can be adapted to include other scorers)"""
"""Abstract interfaces for Backends that implement Process Reward Models (can be adapted to include other scorers)."""

import abc


class PRM(abc.ABC):
"""Mixin for Process Reward Model Backends."""

def __init__(self, model_name_or_path):
"""Sets the self.model_name_or_path. Inheriting classes should implement the remaining logic."""
# Leave implementation of model to inheriting class
self.model_name_or_path = model_name_or_path

@abc.abstractmethod
def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]:
"""Returns a final score and per-step score to the input of the model"""
"""Returns a final score and per-step score to the input of the model."""
...

@abc.abstractmethod
def stepify(self, response: str, step_separator: str) -> list[str]:
"""Splits the assistant response into steps to score
"""Splits the assistant response into steps to score.

Args:
response: assistant response to score
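To make the abstract interface above concrete, a toy subclass; the constant scoring is a placeholder for illustration and is not how the Huggingface PRM in this PR scores steps.

```python
# Toy illustration of the PRM interface; the scores are placeholders, not real model output.
from mellea.backends.process_reward_models import PRM


class ConstantPRM(PRM):
    """Gives every step the same score; for interface illustration only."""

    def stepify(self, response: str, step_separator: str) -> list[str]:
        return [s.strip() for s in response.split(step_separator) if s.strip()]

    def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]:
        steps = self.stepify(response, "\n\n")
        per_step = [[0.5] for _ in steps]  # one placeholder score per step
        return [0.5], per_step  # (final score, per-step scores)


prm = ConstantPRM("dummy-model")  # base __init__ just records model_name_or_path
```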
@@ -1 +1 @@
"""Process Reward Model Implementations with Huggingface backends"""
"""Process Reward Model Implementations with Huggingface backends."""