diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 00230cb0..ae8b1847 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -11,7 +11,7 @@ repos: - id: ruff name: "Ruff linter" args: [--exit-non-zero-on-fix, --fix, --config=pyproject.toml] - files: '^(mellea|tests|cli|docs).*\.(py|ipynb)$' + files: '^(mellea|tests).*\.(py|ipynb)$' - repo: local hooks: diff --git a/cli/alora/upload.py b/cli/alora/upload.py index 475bed86..51491da2 100644 --- a/cli/alora/upload.py +++ b/cli/alora/upload.py @@ -5,8 +5,7 @@ def upload_model(weight_path: str, model_name: str, private: bool = True): - """ - Upload a trained adapter (LoRA/aLoRA) to Hugging Face Hub. + """Upload a trained adapter (LoRA/aLoRA) to Hugging Face Hub. Args: weight_path (str): Directory containing adapter weights (from save_pretrained). diff --git a/cli/m.py b/cli/m.py index a5ce9d2a..3aa32aa1 100644 --- a/cli/m.py +++ b/cli/m.py @@ -12,7 +12,7 @@ # Add a default callback for handling the default cli description. @cli.callback() def callback() -> None: - """Perform M Tasks""" + """Perform M Tasks.""" # Typer assumes that all commands are in the same file/module. diff --git a/docs/examples/best_of_n/prm.py b/docs/examples/best_of_n/prm.py index 4c08c867..d1843810 100644 --- a/docs/examples/best_of_n/prm.py +++ b/docs/examples/best_of_n/prm.py @@ -1,4 +1,4 @@ -"""Example of Using Best of N with PRMs""" +"""Example of Using Best of N with PRMs.""" from docs.examples.helper import w from mellea import start_session diff --git a/docs/examples/generative_slots/generative_slots.py b/docs/examples/generative_slots/generative_slots.py index 2b053d54..7006b965 100644 --- a/docs/examples/generative_slots/generative_slots.py +++ b/docs/examples/generative_slots/generative_slots.py @@ -9,8 +9,7 @@ def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... @generative def generate_summary(text: str) -> str: - """ - This is a function that takes in a string and generates a summary for the string. + """This is a function that takes in a string and generates a summary for the string. Keep your summary succinct and under 20 words. """ diff --git a/docs/examples/information_extraction/101_with_gen_slots.py b/docs/examples/information_extraction/101_with_gen_slots.py index 789c03df..0a3c57a7 100644 --- a/docs/examples/information_extraction/101_with_gen_slots.py +++ b/docs/examples/information_extraction/101_with_gen_slots.py @@ -8,9 +8,7 @@ @generative def extract_all_person_names(doc: str) -> list[str]: - """ - Given a document, extract all person names. Return these names as list of strings. - """ + """Given a document, extract all person names. Return these names as list of strings.""" # ref: https://www.nytimes.com/2012/05/20/world/world-leaders-at-us-meeting-urge-growth-not-austerity.html diff --git a/docs/examples/mini_researcher/researcher.py b/docs/examples/mini_researcher/researcher.py index 22db0e2e..c7f93a5e 100644 --- a/docs/examples/mini_researcher/researcher.py +++ b/docs/examples/mini_researcher/researcher.py @@ -19,20 +19,20 @@ @cache def get_session(): - """get M session (change model here)""" + """Get M session (change model here).""" return MelleaSession(backend=OllamaModelBackend(model_ids.IBM_GRANITE_3_3_8B)) @cache def get_guardian_session(): - """get M session for the guardian model""" + """Get M session for the guardian model.""" return MelleaSession( backend=OllamaModelBackend(model_ids.IBM_GRANITE_GUARDIAN_3_0_2B) ) def is_a_true_subset_of_b(a: list[str], b: list[str]) -> bool: - """check if a is true subset of b.""" + """Check if a is true subset of b.""" all_in = True for e in a: if e not in b: @@ -42,7 +42,7 @@ def is_a_true_subset_of_b(a: list[str], b: list[str]) -> bool: def create_check_word_count(max_words: int) -> Callable[[str], bool]: - """generate a maximum-word-count validation function.""" + """Generate a maximum-word-count validation function.""" def cc(s: str): return len(s.split()) <= max_words @@ -56,7 +56,7 @@ def cc(s: str): def step_is_input_safe(guardian_session: MelleaSession, docs: list[str]) -> bool: - """check if the list of docs has no harm.""" + """Check if the list of docs has no harm.""" is_safe = True for i_doc, doc in enumerate(docs): print(f"\nChecking Doc {i_doc + 1}/{len(docs)}", end="...") @@ -73,7 +73,7 @@ def step_is_input_safe(guardian_session: MelleaSession, docs: list[str]) -> bool def step_summarize_docs( s: MelleaSession, docs: list[str], user_args: dict ) -> list[str]: - """generate a task-specific document summary for each doc.""" + """Generate a task-specific document summary for each doc.""" summaries = [] for i_doc, doc in enumerate(docs): # type: ignore print(f"\nSummarizing doc {i_doc + 1}/{len(docs)}", end="...") @@ -91,7 +91,6 @@ def step_generate_outline( s: MelleaSession, user_args: dict, context: list[RAGDocument] ) -> list[str]: """Generate a report outline using constraint decoding (formatted output).""" - print("\n Generating outline", end="...") class SectionTitles(BaseModel): @@ -165,7 +164,6 @@ def step_write_full_report( outline: list[str], ) -> str: """Merge summaries and outline into a single report.""" - print("\nWriting full report", end="...") ## Define Requirements diff --git a/docs/examples/notebooks/mcp_example.ipynb b/docs/examples/notebooks/mcp_example.ipynb index 315487be..438037dd 100644 --- a/docs/examples/notebooks/mcp_example.ipynb +++ b/docs/examples/notebooks/mcp_example.ipynb @@ -126,7 +126,7 @@ "\n", "@mcp.resource(\"greeting://{name}\")\n", "def get_greeting(name: str) -> str:\n", - " \"\"\"Get a personalized greeting\"\"\"\n", + " \"\"\"Get a personalized greeting.\"\"\"\n", " return f\"Hello, {name}!\"" ] } diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index 989b0777..7711df6f 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -1,4 +1,4 @@ -"""Example of using the Guardian Requirement""" +"""Example of using the Guardian Requirement.""" from mellea import MelleaSession from mellea.backends import model_ids diff --git a/docs/examples/tutorial/mcp_example.py b/docs/examples/tutorial/mcp_example.py index 4904ad56..48c042b5 100644 --- a/docs/examples/tutorial/mcp_example.py +++ b/docs/examples/tutorial/mcp_example.py @@ -1,4 +1,4 @@ -"""Example of an MCP server +"""Example of an MCP server. You need to install the mcp package: uv pip install "mcp[cli]" @@ -50,5 +50,5 @@ def write_a_poem(word_limit: int) -> str: @mcp.resource("greeting://{name}") def get_greeting(name: str) -> str: - """Get a personalized greeting""" + """Get a personalized greeting.""" return f"Hello, {name}!" diff --git a/mellea/backends/aloras/huggingface/granite_aloras.py b/mellea/backends/aloras/huggingface/granite_aloras.py index 2e1e7284..b5e29a47 100644 --- a/mellea/backends/aloras/huggingface/granite_aloras.py +++ b/mellea/backends/aloras/huggingface/granite_aloras.py @@ -33,8 +33,12 @@ def __init__( """Initialize after checking that the backend is correct. Args: + name: name of the alora. + path_or_model_id: huggingface path or model id. + generation_prompt: the prompt required to activate the aLoRa. + backend: a LocalHFBackend that this alora is attached to. constraint_prompt: a template that the constraint can be interpolated into; can only have a single `{}` slot. - include_constraint_in_alora_offset: whether to include the constraint prompt in the alora offset + include_constraint_in_alora_offset: whether to include the constraint prompt in the alora offset. """ super().__init__(name, path_or_model_id, generation_prompt, backend) @@ -142,7 +146,6 @@ def _generate_using_cache( self, cache_hit: HFAloraCacheInfo, constraint: str, force_yn: bool ) -> dict: """Returns the input object used for generation.""" - # Must tokenize the constraint here since the requirement isn't known at initialization. constraint_tokens = self._backend._tokenizer( self._constraint_prompt.format(constraint), return_tensors="pt" @@ -177,7 +180,6 @@ def _generate_not_using_cache( self, input: str, response: str, constraint: str, force_yn: bool ) -> dict: """Returns the input object used for generation.""" - # Params aren't needed when just getting the backend args. backend_model_opts = self._backend._simplify_and_merge(None) sys_prompt = backend_model_opts.get(ModelOption.SYSTEM_PROMPT, None) @@ -226,6 +228,7 @@ async def processing( force_yn: bool, gen_prompt: str, ): + """Called to process the incoming chunks.""" if mot._underlying_value is None: mot._underlying_value = "" @@ -248,6 +251,7 @@ async def processing( async def post_processing(mot: ModelOutputThunk, backend: LocalHFBackend): + """Called after all data has been received.""" backend.formatter.parse(mot._action, mot) # type: ignore diff --git a/mellea/backends/aloras/openai/granite_aloras.py b/mellea/backends/aloras/openai/granite_aloras.py index a6d17172..440faa92 100644 --- a/mellea/backends/aloras/openai/granite_aloras.py +++ b/mellea/backends/aloras/openai/granite_aloras.py @@ -105,12 +105,14 @@ def generate_using_strings( async def processing(mot: ModelOutputThunk, chunk: Completion): + """Called to process the incoming chunks.""" if mot._underlying_value is None: mot._underlying_value = "" mot._underlying_value += chunk.choices[0].text async def post_processing(backend: OpenAIBackend, mot: ModelOutputThunk): + """Called after all data has been received.""" backend.formatter.parse(mot._action, mot) # type: ignore diff --git a/mellea/backends/formatter.py b/mellea/backends/formatter.py index 0a83163c..7a060f42 100644 --- a/mellea/backends/formatter.py +++ b/mellea/backends/formatter.py @@ -37,7 +37,8 @@ def parse( ) -> ModelOutputThunk: """Parses the output from a model and sets the parsed_repr of the result ModelOutputThunk. - Returns the ModelOutputThunk that was passed in.""" + Returns the ModelOutputThunk that was passed in. + """ ... def to_chat_messages(self, cs: list[Component | CBlock]) -> list[Message]: diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index 5903b9c2..598e837f 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -765,10 +765,12 @@ def __init__( class HFProcessRewardModel(PRM, abc.ABC): + """A Process Reward Model that works with a huggingface backend.""" + def __init__( self, model_name_or_path: str, score_token: str, device: str | None = None ): - """Initialize an PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models + """Initialize an PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models. Args: model_name_or_path (str): A local path to PRM or a huggingface PRM @@ -803,13 +805,12 @@ def __init__( )[0] def stepify(self, content: str, step_separator: str) -> list[str]: - """Splits the assistant response into steps to score + """Splits the assistant response into steps to score. Args: content: assistant response to score step_separator: string on which to separate the content into steps """ - # convert assistant message into a list of steps list_of_steps = [ step.strip() for step in content.split(step_separator) if step.strip != "" diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 7716c785..61e495aa 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -321,7 +321,8 @@ async def processing( ): """Called during generation to add information from a single ModelResponse or a chunk / ModelResponseStream to the ModelOutputThunk. - For LiteLLM, tool call parsing is handled in the post processing step.""" + For LiteLLM, tool call parsing is handled in the post processing step. + """ if mot._thinking is None: mot._thinking = "" if mot._underlying_value is None: diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index a9e779b5..6d0190ac 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -539,6 +539,12 @@ async def post_processing( def chat_response_delta_merge(mot: ModelOutputThunk, delta: ollama.ChatResponse): + """Merges the individual ChatResponse chunks from a streaming response into a single ChatResponse. + + Args: + mot: the ModelOutputThunk that the deltas are being used to populated. + delta: the most recent ollama ChatResponse. + """ if mot._meta.get("chat_response", None) is None: mot._meta["chat_response"] = delta return # Return early, no need to merge. diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index fe6b1505..d9eb58a7 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -93,6 +93,7 @@ def __init__( model_options : Generation options to pass to the LLM. Defaults to None. default_to_constraint_checking_alora: If set to False then aloras will be deactivated. This is primarily for performance benchmarking and debugging. api_key : API key for generation. Defaults to None. + kwargs : additional kwargs to pass when creating the OpenAI client. """ super().__init__( model_id=model_id, @@ -220,6 +221,7 @@ def _simplify_and_merge( Args: model_options: the model_options for this call + is_chat_context: set to True if using chat completion api Returns: a new dict @@ -245,6 +247,7 @@ def _make_backend_specific_and_remove( Args: model_options: the model_options for this call + is_chat_context: set to True if using chat completion api Returns: a new dict @@ -372,6 +375,7 @@ def _generate_from_chat_context_alora( @staticmethod def message_to_openai_message(msg: Message): + """Serializes a mellea Message object to the message format required by OpenAI compatible api providers.""" if msg.images is not None: img_list = [ { @@ -526,7 +530,8 @@ async def processing( ): """Called during generation to add information from a single ChatCompletion or ChatCompletionChunk to the ModelOutputThunk. - For OpenAI, tool call parsing is handled in the post processing step.""" + For OpenAI, tool call parsing is handled in the post processing step. + """ if mot._thinking is None: mot._thinking = "" if mot._underlying_value is None: diff --git a/mellea/backends/process_reward_models/__init__.py b/mellea/backends/process_reward_models/__init__.py index ae911f98..5097d5f3 100644 --- a/mellea/backends/process_reward_models/__init__.py +++ b/mellea/backends/process_reward_models/__init__.py @@ -1,21 +1,24 @@ -"""Abstract interfaces for Backends that implement Process Reward Models (can be adapted to include other scorers)""" +"""Abstract interfaces for Backends that implement Process Reward Models (can be adapted to include other scorers).""" import abc class PRM(abc.ABC): + """Mixin for Process Reward Model Backends.""" + def __init__(self, model_name_or_path): + """Sets the self.model_name_or_path. Inheriting classes should implement the remaining logic.""" # Leave implementation of model to inheriting class self.model_name_or_path = model_name_or_path @abc.abstractmethod def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: - """Returns a final score and per-step score to the input of the model""" + """Returns a final score and per-step score to the input of the model.""" ... @abc.abstractmethod def stepify(self, response: str, step_separator: str) -> list[str]: - """Splits the assistant response into steps to score + """Splits the assistant response into steps to score. Args: response: assistant response to score diff --git a/mellea/backends/process_reward_models/huggingface/__init__.py b/mellea/backends/process_reward_models/huggingface/__init__.py index 3b259046..35adf9ac 100644 --- a/mellea/backends/process_reward_models/huggingface/__init__.py +++ b/mellea/backends/process_reward_models/huggingface/__init__.py @@ -1 +1 @@ -"""Process Reward Model Implementations with Huggingface backends""" +"""Process Reward Model Implementations with Huggingface backends.""" diff --git a/mellea/backends/process_reward_models/huggingface/prms.py b/mellea/backends/process_reward_models/huggingface/prms.py index 2bac7afb..2525b8e6 100644 --- a/mellea/backends/process_reward_models/huggingface/prms.py +++ b/mellea/backends/process_reward_models/huggingface/prms.py @@ -1,3 +1,5 @@ +"""PRM Implementations for Local HuggingFace Backends.""" + import torch from transformers.tokenization_utils_base import BatchEncoding @@ -5,6 +7,8 @@ class HFGenerativePRM(HFProcessRewardModel): + """A Generative PRM that works with a huggingface backend.""" + def __init__( self, model_name_or_path: str = "ibm-granite/granite-3.3-8b-lora-math-prm", @@ -13,7 +17,7 @@ def __init__( generation_prompt: str = "Is this response correct so far (Y/N)?", step_separator: str = "\n\n", ): - """Initialize a Generative PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models + """Initialize a Generative PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models. Args: model_name_or_path (str): A local path to PRM or a huggingface PRM @@ -30,13 +34,12 @@ def __init__( self.softmax = torch.nn.Softmax(dim=-1) def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: - """Returns a final and per-step score for a given input query and response + """Returns a final and per-step score for a given input query and response. Args: query (str): User query response (str): Assistant Response to score """ - list_of_steps = self.stepify(response, self.step_separator) # get tokenized batch batches = self.prepare_inputs(query, list_of_steps) @@ -114,7 +117,7 @@ def score(self, query: str, response: str) -> tuple[list[float], list[list[float return all_rewards, all_rewards_per_step def prepare_inputs(self, user_content: str, steps: list[str]) -> BatchEncoding: - """Prepare the inputs for inference with the model + """Prepare the inputs for inference with the model. Args: user_content (str): the user query @@ -153,6 +156,8 @@ def prepare_inputs(self, user_content: str, steps: list[str]) -> BatchEncoding: class HFRegressionPRM(HFProcessRewardModel): + """A Regression PRM that works with a huggingface backend.""" + def __init__( self, model_name_or_path: str, @@ -160,12 +165,12 @@ def __init__( device: str | None = None, step_separator: str = "\n\n", ): - """Initialize a Regression PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models + """Initialize a Regression PRM that works with a huggingface backend. Currently supports and tested with IBM Process Reward Models. Args: model_name_or_path (str): A local path to PRM or a huggingface PRM score_token (str): token who's logits correspond to the PRM score. Usually is a step demarker (for non-generative PRMs) - backend (LocalHFBackend): Mained as a pointer to the backend to which this this PRM is attached. + device (str): pointer to the device on which to run the model step_separator (str): string on which to separate the input content into steps """ super().__init__(model_name_or_path, score_token, device) @@ -189,13 +194,12 @@ def __init__( self.softmax = torch.nn.Softmax(dim=-1) def score(self, query: str, response: str) -> tuple[list[float], list[list[float]]]: - """Returns a final and per-step score for a given input query and response + """Returns a final and per-step score for a given input query and response. Args: query (str): User query response (str): Assistant Response to score """ - list_of_steps = self.stepify(response, self.step_separator) # tokenizes the batch and concatenates the list of steps into a single step-separated response batch = self.prepare_inputs(query, list_of_steps) @@ -232,7 +236,7 @@ def score(self, query: str, response: str) -> tuple[list[float], list[list[float return rewards, prm_probs def prepare_inputs(self, user_content: str, steps: list[str]) -> BatchEncoding: - """Prepare the inputs for inference with the model + """Prepare the inputs for inference with the model. Args: user_content (str): the user query diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index 9e3d9288..31d33a4a 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -64,8 +64,8 @@ def __init__( model_options : Global model options to pass to the model. Defaults to None. api_key : watsonx API key. Defaults to None. project_id : watsonx project ID. Defaults to None. + kwargs : extra kwargs passed to model inference creation. """ - # There are bugs with the Watsonx python sdk related to async event loops; # using the same watsonx backend across multiple event loops causes errors. warnings.warn( @@ -176,6 +176,7 @@ def _simplify_and_merge( Args: model_options: the model_options for this call + is_chat_context: set to True if used for chat completion apis Returns: a new dict @@ -201,6 +202,7 @@ def _make_backend_specific_and_remove( Args: model_options: the model_options for this call + is_chat_context: set to True if used for chat completion apis Returns: a new dict @@ -369,7 +371,8 @@ def generate_from_chat_context( async def processing(self, mot: ModelOutputThunk, chunk: dict): """Called during generation to add information from a single ChatCompletion or ChatCompletionChunk to the ModelOutputThunk. - For OpenAI-like APIs, tool call parsing is handled in the post processing step.""" + For OpenAI-like APIs, tool call parsing is handled in the post processing step. + """ if mot._thinking is None: mot._thinking = "" if mot._underlying_value is None: diff --git a/mellea/helpers/async_helpers.py b/mellea/helpers/async_helpers.py index 0d3e1866..2f6c3c61 100644 --- a/mellea/helpers/async_helpers.py +++ b/mellea/helpers/async_helpers.py @@ -1,3 +1,5 @@ +"""Async helper functions.""" + import asyncio from collections.abc import AsyncIterator, Coroutine from typing import Any @@ -37,7 +39,8 @@ async def wait_for_all_mots(mots: list[ModelOutputThunk]): """Helper function to make waiting for multiple ModelOutputThunks to be computed easier. All ModelOutputThunks must be from the same event loop. This should always be the case in sampling - functions, session functions, and top-level mellea functions.""" + functions, session functions, and top-level mellea functions. + """ coroutines: list[Coroutine[Any, Any, str]] = [] for mot in mots: coroutines.append(mot.avalue()) diff --git a/mellea/helpers/openai_compatible_helpers.py b/mellea/helpers/openai_compatible_helpers.py index 87daf0bc..38befda4 100644 --- a/mellea/helpers/openai_compatible_helpers.py +++ b/mellea/helpers/openai_compatible_helpers.py @@ -46,7 +46,6 @@ def chat_completion_delta_merge( chunks: the list of dicts that represent the message deltas force_all_tool_calls_separate: if `True`, tool calls in separate message deltas will not be merged (even if their index values are the same); use when providers do not return the correct index value for tool calls. If using this option, all tool calls must be fully populated in a single delta since they won't be merged. """ - merged: dict[str, Any] = dict() # `delta`s map to a single choice. diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index dc09c23c..2cb8daa6 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -318,7 +318,8 @@ async def astream(self) -> str: def __repr__(self): """Provides a python-parsable representation (usually). - Differs from CBlock because `._meta` can be very large for ModelOutputThunks.""" + Differs from CBlock because `._meta` can be very large for ModelOutputThunks. + """ return f"ModelOutputThunk({self.value})" @@ -371,7 +372,6 @@ def from_previous( cls: type[ContextT], previous: Context, data: Component | CBlock ) -> ContextT: """Constructs a new context from an existing context.""" - assert isinstance(previous, Context), ( "Cannot create a new context from a non-Context object." ) @@ -422,7 +422,8 @@ def is_chat_context(self) -> bool: def as_list(self, last_n_components: int | None = None) -> list[Component | CBlock]: """Returns a list of the last n components in the context sorted from FIRST TO LAST. - If `last_n_components` is `None`, then all components are returned.""" + If `last_n_components` is `None`, then all components are returned. + """ context_list: list[Component | CBlock] = [] current_context: Context = self @@ -455,7 +456,6 @@ def actions_for_available_tools(self) -> list[Component | CBlock] | None: def last_output(self, check_last_n_components: int = 3) -> ModelOutputThunk | None: """The last output thunk of the context.""" - for c in self.as_list(last_n_components=check_last_n_components)[::-1]: if isinstance(c, ModelOutputThunk): return c @@ -466,7 +466,6 @@ def last_turn(self): This can be partial. If the last event is an input, then the output is None. """ - history = self.as_list(last_n_components=2) if len(history) == 0: @@ -506,11 +505,13 @@ def __init__(self, *, window_size: int | None = None): self._window_size = window_size def add(self, c: Component | CBlock) -> ChatContext: + """Add a new component/cblock to the context. Returns the new context.""" new = ChatContext.from_previous(self, c) new._window_size = self._window_size return new def view_for_generation(self) -> list[Component | CBlock] | None: + """Returns the context in a linearized form. Uses the window_size set during initialization.""" return self.as_list(self._window_size) @@ -518,9 +519,11 @@ class SimpleContext(Context): """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved..""" def add(self, c: Component | CBlock) -> SimpleContext: + """Add a new component/cblock to the context. Returns the new context.""" return SimpleContext.from_previous(self, c) def view_for_generation(self) -> list[Component | CBlock] | None: + """Returns an empty list.""" return [] diff --git a/mellea/stdlib/chat.py b/mellea/stdlib/chat.py index 084727b9..7f5bbb4a 100644 --- a/mellea/stdlib/chat.py +++ b/mellea/stdlib/chat.py @@ -32,6 +32,7 @@ def __init__( Args: role (str): The role that this message came from (e.g., user, assistant). content (str): The content of the message. + images (list[ImageBlock]): The images associated with the message if any. """ self.role = role self.content = content diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index 5e9ff9df..6ce6e314 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -90,7 +90,6 @@ def act( Returns: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - out = _run_async_in_thread( _act( action, @@ -279,7 +278,6 @@ def instruct( tool_calls: If true, tool calling is enabled. images: A list of images to be used in the instruction or None if none. """ - requirements = [] if requirements is None else requirements icl_examples = [] if icl_examples is None else icl_examples grounding_context = dict() if grounding_context is None else grounding_context @@ -489,10 +487,12 @@ def transform( """Transform method for creating a new object with the transformation applied. Args: - obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + obj: The object to be queried. It should be an instance of MObject or can be converted to one if necessary. transformation: The string representing the query to be executed against the object. context: the context being used as a history from which to generate the response. backend: the backend used to generate the response. + format: format for output parsing; usually not needed with transform. + model_options: Model options to pass to the backend. Returns: ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py index 9040255f..1e822871 100644 --- a/mellea/stdlib/genslot.py +++ b/mellea/stdlib/genslot.py @@ -163,7 +163,9 @@ def __call__( Args: m: MelleaSession: A mellea session (optional, uses context if None) - **kwargs: Additional Kwargs to be passed to the func + model_options: Model options to pass to the backend. + *args: Additional args to be passed to the func. + **kwargs: Additional Kwargs to be passed to the func. Returns: ModelOutputThunk: Output with generated Thunk. diff --git a/mellea/stdlib/mify.py b/mellea/stdlib/mify.py index 5b278f2d..ec5e5c25 100644 --- a/mellea/stdlib/mify.py +++ b/mellea/stdlib/mify.py @@ -229,7 +229,7 @@ def mify( ) -> T: ... # Overloads for @mify and mify(obj|cls) -def mify(*args, **kwargs): +def mify(*args, **kwargs): # noqa: D417 """M-ify an object or class. Allows the object (or instances of the class) to be used in m sessions and with m functions. diff --git a/mellea/stdlib/reqlib/md.py b/mellea/stdlib/reqlib/md.py index 904bd4ae..3cee2770 100644 --- a/mellea/stdlib/reqlib/md.py +++ b/mellea/stdlib/reqlib/md.py @@ -9,6 +9,7 @@ def as_markdown_list(ctx: Context) -> list[str] | None: + """Attempts to format the last_output of the given context as a markdown list.""" xs = list() raw_output = ctx.last_output() assert raw_output is not None diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index 64747e38..f10a3aaf 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -65,21 +65,25 @@ def __init__( @property def reason(self) -> str | None: + """Reason for the validation result.""" return self._reason @property def score(self) -> float | None: + """An optional score for the validation result.""" return self._score @property def thunk(self) -> ModelOutputThunk | None: + """The ModelOutputThunk associated with the validation func if an llm was used to generate the final result.""" return self._thunk def as_bool(self) -> bool: - """""" + """Return a boolean value based on the result.""" return self._result def __bool__(self) -> bool: + """Return a boolean value based on the result.""" return self.as_bool() diff --git a/mellea/stdlib/rewards/__init__.py b/mellea/stdlib/rewards/__init__.py index e69de29b..3c1aede3 100644 --- a/mellea/stdlib/rewards/__init__.py +++ b/mellea/stdlib/rewards/__init__.py @@ -0,0 +1 @@ +"""Components used with reward models.""" diff --git a/mellea/stdlib/rewards/prm_scorer.py b/mellea/stdlib/rewards/prm_scorer.py index 0c46dcbe..5653cd99 100644 --- a/mellea/stdlib/rewards/prm_scorer.py +++ b/mellea/stdlib/rewards/prm_scorer.py @@ -1,3 +1,5 @@ +"""PRM Requirements.""" + from mellea.backends.huggingface import HFProcessRewardModel from mellea.stdlib.base import CBlock, Context from mellea.stdlib.chat import Message @@ -10,11 +12,11 @@ class PRMScorer(ScorerRequirement): def __init__( self, *, prm_model: HFProcessRewardModel, preference_ordering: str = "max" ): - """ + """Instantiate a process reward model scorer based on local huggingface backend. Args: prm_model: The PRM model - preference_ordering: indicates whether the goal is to maximize or minimize the score. must be either "max" or "min" + preference_ordering: indicates whether the goal is to maximize or minimize the score. must be either "max" or "min". """ super().__init__( check_only=True, @@ -25,9 +27,7 @@ def __init__( self.model: HFProcessRewardModel = prm_model def _prm_validate(self, ctx: Context): - """ - Returns PRM score of last turn of context - """ + """Returns PRM score of last turn of context.""" last_turn = ctx.last_turn() assert last_turn is not None diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index 1dff0d6d..3a6266b5 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -99,7 +99,6 @@ def _guardian_validate(self, ctx: Context): Returns: bool: True if there is no identified risk, False otherwise. """ - messages: list[dict[str, str]] = [] last_turn = ctx.last_turn() diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 12532f7b..5ca10a23 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -1,3 +1,5 @@ +"""Base Sampling Strategies.""" + import abc from copy import deepcopy @@ -47,8 +49,7 @@ def repair( past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], ) -> tuple[Component, Context]: - """ - Repair function that is being invoked if not all requirements are fulfilled. It should return a next action component. + """Repair function that is being invoked if not all requirements are fulfilled. It should return a next action component. Args: old_ctx: The context WITHOUT the last action + output. @@ -99,9 +100,13 @@ async def sample( Args: action : The action object to be sampled. context: The context to be passed to the sampling strategy. - show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + backend: The backend used for generating samples. requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. + format: output format for structured outputs. + model_options: model options to pass to the backend during generation / validation. + tool_calls: True if tool calls should be used during this sampling strategy. + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. @@ -244,7 +249,16 @@ def select_from_failure( sampled_results: list[ModelOutputThunk], sampled_val: list[list[tuple[Requirement, ValidationResult]]], ) -> int: - # simply returns the first attempt if all loops fail + """Always returns the 0th index. + + Args: + sampled_actions: List of actions that have been executed (without success). + sampled_results: List of (unsuccessful) generation results for these actions. + sampled_val: List of validation results for the results. + + Returns: + The index of the result that should be selected as `.value`. + """ return 0 @staticmethod @@ -255,7 +269,18 @@ def repair( past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], ) -> tuple[Component, Context]: - # repeat the last action again. + """Always returns the unedited, last action. + + Args: + old_ctx: The context WITHOUT the last action + output. + new_ctx: The context including the last action + output. + past_actions: List of actions that have been executed (without success). + past_results: List of (unsuccessful) generation results for these actions. + past_val: List of validation results for the results. + + Returns: + The next action component and context to be used for the next generation attempt. + """ return past_actions[-1], old_ctx @@ -268,7 +293,16 @@ def select_from_failure( sampled_results: list[ModelOutputThunk], sampled_val: list[list[tuple[Requirement, ValidationResult]]], ) -> int: - # simply returns the first attempt if all loops fail + """Always returns the 0th index. + + Args: + sampled_actions: List of actions that have been executed (without success). + sampled_results: List of (unsuccessful) generation results for these actions. + sampled_val: List of validation results for the results. + + Returns: + The index of the result that should be selected as `.value`. + """ return 0 @staticmethod @@ -279,6 +313,18 @@ def repair( past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], ) -> tuple[Component, Context]: + """Adds a description of the requirements that failed to a copy of the original instruction. + + Args: + old_ctx: The context WITHOUT the last action + output. + new_ctx: The context including the last action + output. + past_actions: List of actions that have been executed (without success). + past_results: List of (unsuccessful) generation results for these actions. + past_val: List of validation results for the results. + + Returns: + The next action component and context to be used for the next generation attempt. + """ pa = past_actions[-1] if isinstance(pa, Instruction): last_failed_reqs: list[Requirement] = [ @@ -302,7 +348,16 @@ def select_from_failure( sampled_results: list[ModelOutputThunk], sampled_val: list[list[tuple[Requirement, ValidationResult]]], ): - # return the last assistant message even if all attempts of repair failed. + """Always returns the last index. The last message from the model will always be returned if all results are failures. + + Args: + sampled_actions: List of actions that have been executed (without success). + sampled_results: List of (unsuccessful) generation results for these actions. + sampled_val: List of validation results for the results. + + Returns: + The index of the result that should be selected as `.value`. + """ return -1 @staticmethod @@ -313,6 +368,18 @@ def repair( past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], ) -> tuple[Component, Context]: + """Returns a Message with a description of the failed requirements. + + Args: + old_ctx: The context WITHOUT the last action + output. + new_ctx: The context including the last action + output. + past_actions: List of actions that have been executed (without success). + past_results: List of (unsuccessful) generation results for these actions. + past_val: List of validation results for the results. + + Returns: + The next action component and context to be used for the next generation attempt. + """ assert isinstance(new_ctx, ChatContext), ( " Need chat context to run agentic sampling." ) diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index 53cfa83c..d0bdf341 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -1,3 +1,5 @@ +"""Best of N Sampling Strategy.""" + from copy import deepcopy import tqdm @@ -13,9 +15,7 @@ class BestofNSamplingStrategy(BaseSamplingStrategy): - """ - Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer - """ + """Sampling strategy that selects the best response from a set of samples as given by a Requirement Scorer.""" async def sample( self, @@ -35,9 +35,13 @@ async def sample( Args: action : The action object to be sampled. context: The context to be passed to the sampling strategy. - show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. + backend: The backend used for generating samples. requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. + format: output format for structured outputs. + model_options: model options to pass to the backend during generation / validation. + tool_calls: True if tool calls should be used during this sampling strategy. + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. @@ -236,8 +240,16 @@ def select_from_failure( sampled_results: list[ModelOutputThunk], sampled_val: list[list[tuple[Requirement, ValidationResult]]], ) -> int: - # select attempt with highest ScoreRequirementScore if all loops fail + """Selects the attempt with the highest score. + Args: + sampled_actions: List of actions that have been executed (without success). + sampled_results: List of (unsuccessful) generation results for these actions. + sampled_val: List of validation results for the results. + + Returns: + The index of the result that should be selected as `.value`. + """ scores: list[float | None] = [] for sample in sampled_val: @@ -258,6 +270,18 @@ def repair( past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], ) -> tuple[Component, Context]: + """Adds a description of the requirements that failed to a copy of the original instruction. + + Args: + old_ctx: The context WITHOUT the last action + output. + new_ctx: The context including the last action + output. + past_actions: List of actions that have been executed (without success). + past_results: List of (unsuccessful) generation results for these actions. + past_val: List of validation results for the results. + + Returns: + The next action component and context to be used for the next generation attempt. + """ pa = past_actions[-1] if isinstance(pa, Instruction): last_failed_reqs: list[Requirement] = [ diff --git a/mellea/stdlib/sampling/types.py b/mellea/stdlib/sampling/types.py index e5b81130..391b2c89 100644 --- a/mellea/stdlib/sampling/types.py +++ b/mellea/stdlib/sampling/types.py @@ -1,3 +1,5 @@ +"""Base types for sampling.""" + import abc from mellea.backends import Backend, BaseModelSubclass @@ -22,10 +24,12 @@ def __init__( """Initialize a new instance of sampling results. Args: - result: The final output or result from applying the sampling strategy. + result_index: The index of the final output or result from applying the sampling strategy. success: A boolean indicating whether the operation was successful. sample_generations: A list containing intermediate generations produced during the process. sample_validations: For each generation a list of tuples of a requirement and a validation result. + sample_actions: A list of intermediate actions used to produce sampling results. + sample_contexts: A list of contexts produced by the generation results. """ if sample_generations is None: sample_generations = [] @@ -53,18 +57,22 @@ def __init__( @property def result(self) -> ModelOutputThunk: + """The final output or result from applying the sampling strategy.""" return self.sample_generations[self.result_index] @property def result_ctx(self) -> Context: + """The context of the final output or result from applying the sampling strategy.""" return self.sample_contexts[self.result_index] @property def result_action(self) -> Component: + """The action that generated the final output or result from applying the sampling strategy.""" return self.sample_actions[self.result_index] @property def result_validations(self) -> list[tuple[Requirement, ValidationResult]]: + """The validation results associated with the final output or result from applying the sampling strategy.""" return self.sample_validations[self.result_index] @@ -88,13 +96,20 @@ async def sample( model_options: dict | None = None, tool_calls: bool = False, ) -> SamplingResult: - """This method is the abstract method for sampling a given instruction. + """This method is the abstract method for sampling a given component. It must be implemented by any concrete subclasses to provide specific sampling logic. Args: action : The action object to be sampled. context: The context to be passed to the sampling strategy. - requirements: The requirements to be used by the sampling strategy (merged with global requirements). + backend: The backend used for generating samples. + requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. + format: output format for structured outputs. + model_options: model options to pass to the backend during generation / validation. + tool_calls: True if tool calls should be used during this sampling strategy. + + Returns: + SamplingResult: A result object indicating the success or failure of the sampling process. """ diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index ef55c9f8..1b2f38c6 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -251,7 +251,7 @@ def act( tool_calls: bool = False, ) -> SamplingResult: ... - def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: + def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: # noqa: D417 """Runs a generic action, and adds both the action and the result to the context. Args: @@ -266,7 +266,6 @@ def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: Returns: A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - result, context = mfuncs.act( action, context=self.ctx, backend=self.backend, **kwargs ) @@ -311,7 +310,7 @@ def instruct( tool_calls: bool = False, ) -> SamplingResult: ... - def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingResult: + def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingResult: # noqa: D417 """Generates from an instruction. Args: @@ -329,7 +328,6 @@ def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingRes tool_calls: If true, tool calling is enabled. images: A list of images to be used in the instruction or None if none. """ - r = mfuncs.instruct( description, context=self.ctx, backend=self.backend, **kwargs ) @@ -354,7 +352,6 @@ def chat( tool_calls: bool = False, ) -> Message: """Sends a simple chat message and returns the response. Adds both messages to the Context.""" - result, context = mfuncs.chat( content=content, context=self.ctx, @@ -381,7 +378,6 @@ def validate( input: CBlock | None = None, ) -> list[ValidationResult]: """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - return mfuncs.validate( reqs=reqs, context=self.ctx, @@ -439,6 +435,8 @@ def transform( Args: obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. transformation: The string representing the query to be executed against the object. + format: format for output parsing; usually not needed with transform. + model_options: Model options to pass to the backend. Returns: ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, @@ -465,7 +463,6 @@ def last_prompt(self) -> str | list[dict] | None: Returns: A string if the last prompt was a raw call to the model OR a list of messages (as role-msg-dicts). Is None if none could be found. """ - op = self.ctx.last_output() if op is None: return None diff --git a/pyproject.toml b/pyproject.toml index 3bd6c83a..f2ca4269 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -115,7 +115,7 @@ select = [ # "B", # flake8-bugbear "C", # flake8-comprehensions "C9", # mccabe - # "D", # flake8-docstrings + "D", # flake8-docstrings "E", # pycodestyle errors (default) "F", # pyflakes (default) "I", # isort