diff --git a/docs/examples/aLora/101_example.py b/docs/examples/aLora/101_example.py index 92adaf13..1b65509f 100644 --- a/docs/examples/aLora/101_example.py +++ b/docs/examples/aLora/101_example.py @@ -1,10 +1,10 @@ import time -from mellea import LinearContext, MelleaSession +from mellea import MelleaSession from mellea.backends.aloras.huggingface.granite_aloras import HFConstraintAlora from mellea.backends.cache import SimpleLRUCache from mellea.backends.huggingface import LocalHFBackend -from mellea.stdlib.base import GenerateLog +from mellea.stdlib.base import ChatContext, GenerateLog from mellea.stdlib.requirement import ALoraRequirement, Requirement # Define a backend and add the constraint aLora @@ -22,7 +22,7 @@ backend.add_alora(custom_stembolt_failure_constraint) # Create M session -m = MelleaSession(backend, ctx=LinearContext()) +m = MelleaSession(backend, ctx=ChatContext()) # define a requirement failure_check = ALoraRequirement( diff --git a/docs/examples/agents/react.py b/docs/examples/agents/react.py index 80e612f9..9cb594f8 100644 --- a/docs/examples/agents/react.py +++ b/docs/examples/agents/react.py @@ -2,7 +2,7 @@ import inspect import json from collections.abc import Callable -from typing import Literal, Unpack +from typing import Literal import pydantic from jinja2 import Template @@ -13,6 +13,7 @@ import mellea.stdlib import mellea.stdlib.base import mellea.stdlib.chat +from mellea.stdlib.base import ChatContext react_system_template: Template = Template( """Answer the user's question as best you can. @@ -83,7 +84,7 @@ def call_tool(self, tool: ReactTool, kwargs_json: str): def tool_name_schema(self): names = self.tool_names() fields = dict() - fields["tool"] = Literal[Unpack[names]] + fields["tool"] = Literal[*names] return pydantic.create_model("ToolSelectionSchema", **fields) def get_tool_from_schema(self, content: str): @@ -103,7 +104,7 @@ def react( react_toolbox: ReactToolbox, ): assert m.ctx.is_chat_context, "ReACT requires a chat context." - test_ctx_lin = m.ctx.render_for_generation() + test_ctx_lin = m.ctx.view_for_generation() assert test_ctx_lin is not None and len(test_ctx_lin) == 0, ( "ReACT expects a fresh context." ) @@ -114,8 +115,9 @@ def react( ) # Add the system prompt and the goal to the chat history. 
- m.ctx.insert(mellea.stdlib.chat.Message(role="system", content=_sys_prompt)) - m.ctx.insert(mellea.stdlib.chat.Message(role="user", content=f"{goal}")) + m.ctx = m.ctx.add( + mellea.stdlib.chat.Message(role="system", content=_sys_prompt) + ).add(mellea.stdlib.chat.Message(role="user", content=f"{goal}")) # The main ReACT loop as a dynamic program: # ( ?(not done) ; @@ -156,7 +158,7 @@ def react( print("### Observation") tool_output = react_toolbox.call_tool(selected_tool, act_args.content) - m.ctx.insert(mellea.stdlib.chat.Message(role="tool", content=tool_output)) + m.ctx = m.ctx.add(mellea.stdlib.chat.Message(role="tool", content=tool_output)) print(tool_output) print("### Done Check") @@ -178,7 +180,7 @@ def react( if __name__ == "__main__": - m = mellea.start_session(ctx=mellea.stdlib.base.LinearContext()) + m = mellea.start_session(ctx=ChatContext()) def zip_lookup_tool_fn(city: str): """Returns the ZIP code for the `city`.""" diff --git a/docs/examples/agents/react_instruct.py b/docs/examples/agents/react_instruct.py index 69d49d42..5102c650 100644 --- a/docs/examples/agents/react_instruct.py +++ b/docs/examples/agents/react_instruct.py @@ -2,7 +2,7 @@ import inspect import json from collections.abc import Callable -from typing import Literal, Unpack +from typing import Literal import pydantic from jinja2 import Template @@ -11,6 +11,7 @@ import mellea.stdlib import mellea.stdlib.base import mellea.stdlib.chat +from mellea.stdlib.base import ChatContext react_system_template: Template = Template( """Answer the user's question as best you can. @@ -81,7 +82,7 @@ def call_tool(self, tool: ReactTool, kwargs_json: str): def tool_name_schema(self): names = self.tool_names() fields = dict() - fields["tool"] = Literal[Unpack[names]] + fields["tool"] = Literal[*names] return pydantic.create_model("ToolSelectionSchema", **fields) def get_tool_from_schema(self, content: str): @@ -101,7 +102,7 @@ def react( react_toolbox: ReactToolbox, ): assert m.ctx.is_chat_context, "ReACT requires a chat context." - test_ctx_lin = m.ctx.render_for_generation() + test_ctx_lin = m.ctx.view_for_generation() assert test_ctx_lin is not None and len(test_ctx_lin) == 0, ( "ReACT expects a fresh context." ) @@ -112,8 +113,9 @@ def react( ) # Add the system prompt and the goal to the chat history. 
- m.ctx.insert(mellea.stdlib.chat.Message(role="system", content=_sys_prompt)) - m.ctx.insert(mellea.stdlib.chat.Message(role="user", content=f"{goal}")) + m.ctx = m.ctx.add( + mellea.stdlib.chat.Message(role="system", content=_sys_prompt) + ).add(mellea.stdlib.chat.Message(role="user", content=f"{goal}")) # The main ReACT loop as a dynamic program: # ( ?(not done) ; @@ -159,7 +161,7 @@ def react( print("### Observation") tool_output = react_toolbox.call_tool(selected_tool, act_args_val) - m.ctx.insert(mellea.stdlib.chat.Message(role="tool", content=tool_output)) + m.ctx = m.ctx.add(mellea.stdlib.chat.Message(role="tool", content=tool_output)) print(tool_output) print("### Done Check") @@ -187,7 +189,7 @@ def react( if __name__ == "__main__": - m = mellea.start_session(ctx=mellea.stdlib.base.LinearContext()) + m = mellea.start_session(ctx=ChatContext()) def zip_lookup_tool_fn(city: str): """Returns the ZIP code for the `city`.""" diff --git a/docs/examples/generative_slots/generate_with_context.py b/docs/examples/generative_slots/generate_with_context.py index e0d13ef3..98050523 100644 --- a/docs/examples/generative_slots/generate_with_context.py +++ b/docs/examples/generative_slots/generate_with_context.py @@ -1,6 +1,6 @@ -from mellea import LinearContext, generative, start_session +from mellea import generative, start_session from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock +from mellea.stdlib.base import CBlock, ChatContext # Generative slots can be used with sessions that have context. # By utilizing context, you can change the results of several @@ -34,7 +34,7 @@ def give_feedback(essay: str) -> list[str]: if __name__ == "__main__": m = start_session( - ctx=LinearContext(), model_options={ModelOption.MAX_NEW_TOKENS: 100} + ctx=ChatContext(), model_options={ModelOption.MAX_NEW_TOKENS: 100} ) text = """ @@ -55,7 +55,7 @@ def give_feedback(essay: str) -> list[str]: # If you have a set of generative functions, you can tweak them all by # adding context to the session they are running in. - m.ctx.insert( + m.ctx = m.ctx.add( CBlock( "You are an elementary school teacher. " "Any grades and feedback that you give should keep that in mind. Remember to be " @@ -74,7 +74,7 @@ def give_feedback(essay: str) -> list[str]: # And, let's reset the context and try a different grading style. m.reset() - m.ctx.insert( + m.ctx = m.ctx.add( CBlock( "You are a grammarian that is focused solely on spelling and syntax, " "not on the content of essays. When giving grades and feedback, focus " diff --git a/docs/examples/helper/__init__.py b/docs/examples/helper/__init__.py index 847354bc..22104405 100644 --- a/docs/examples/helper/__init__.py +++ b/docs/examples/helper/__init__.py @@ -1 +1 @@ -from .helpers import Any, fill, w +from .helpers import req_print, w diff --git a/docs/examples/helper/helpers.py b/docs/examples/helper/helpers.py index e3202533..f6b412f4 100644 --- a/docs/examples/helper/helpers.py +++ b/docs/examples/helper/helpers.py @@ -1,7 +1,14 @@ from textwrap import fill from typing import Any +from mellea.stdlib.requirement import Requirement, ValidationResult + # Just for printing stuff nicely... 
def w(x: Any) -> str: return fill(str(x), width=120, replace_whitespace=False) + + +def req_print(rv_list: list[tuple[Requirement, ValidationResult]]) -> str: + parts = [f"{bool(rv[1])}\t: {rv[0].description}" for rv in rv_list] + return "\n".join(parts) diff --git a/docs/examples/image_text_models/vision_ollama_chat.py b/docs/examples/image_text_models/vision_ollama_chat.py index feee7334..f2552636 100644 --- a/docs/examples/image_text_models/vision_ollama_chat.py +++ b/docs/examples/image_text_models/vision_ollama_chat.py @@ -2,11 +2,11 @@ from PIL import Image -from mellea import LinearContext, start_session -from mellea.stdlib.base import ImageBlock +from mellea import start_session +from mellea.stdlib.base import ChatContext, ImageBlock -m = start_session(model_id="granite3.2-vision", ctx=LinearContext()) -# m = start_session(model_id="llava", ctx=LinearContext()) +m = start_session(model_id="granite3.2-vision", ctx=ChatContext()) +# m = start_session(model_id="llava", ctx=ChatContext()) # load image test_img = Image.open("pointing_up.jpg") diff --git a/docs/examples/instruct_validate_repair/101_email.py b/docs/examples/instruct_validate_repair/101_email.py index 4097c3fc..ca7bd417 100644 --- a/docs/examples/instruct_validate_repair/101_email.py +++ b/docs/examples/instruct_validate_repair/101_email.py @@ -1,30 +1,19 @@ # This is the 101 example for using `session` and `instruct`. # helper function to wrap text from docs.examples.helper import w -from mellea import instruct, start_session +from mellea import start_session from mellea.backends.types import ModelOption -# create a session using Granite 3.3 8B on Ollama and a simple context [see below] -with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}): - # write an email - email_v1 = instruct("Write an email to invite all interns to the office party.") - with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m: # write an email email_v1 = m.instruct("Write an email to invite all interns to the office party.") + print(m.last_prompt()) # print result print(f"***** email ****\n{w(email_v1)}\n*******") # ************** END ************* - -# # optionally: print the debug log for the last instruction on the context -# from mellea.stdlib.base import GenerateLog -# _, log = m.ctx.last_output_and_logs() -# if isinstance(log, GenerateLog): # should be -# print(f"Prompt:\n{w(log.prompt)}") # print prompt - # # start_session() is equivalent to: # from mellea.backends import model_ids # from mellea.backends.ollama import OllamaModelBackend diff --git a/docs/examples/instruct_validate_repair/101_email_with_requirements.py b/docs/examples/instruct_validate_repair/101_email_with_requirements.py index 0ddaffa6..5be7fdc3 100644 --- a/docs/examples/instruct_validate_repair/101_email_with_requirements.py +++ b/docs/examples/instruct_validate_repair/101_email_with_requirements.py @@ -5,7 +5,7 @@ # create a session using Granite 3.3 8B on Ollama and a simple context [see below] m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) -# write an email +# write an email with automatic requirement checking. 
email_v1 = m.instruct( "Write an email to invite all interns to the office party.", requirements=["be formal", "Use 'Dear interns' as greeting."], diff --git a/docs/examples/instruct_validate_repair/101_email_with_validate.py b/docs/examples/instruct_validate_repair/101_email_with_validate.py index 72e0ad92..efa232e2 100644 --- a/docs/examples/instruct_validate_repair/101_email_with_validate.py +++ b/docs/examples/instruct_validate_repair/101_email_with_validate.py @@ -1,4 +1,4 @@ -from docs.examples.helper import w +from docs.examples.helper import req_print, w from mellea import start_session from mellea.backends.types import ModelOption from mellea.stdlib.sampling import RejectionSamplingStrategy @@ -6,14 +6,24 @@ # create a session using Granite 3.3 8B on Ollama and a simple context [see below] m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) -email_v1 = m.instruct( +email_v2_samples = m.instruct( "Write an email to invite all interns to the office party.", requirements=["be formal", "Use 'Dear interns' as greeting."], strategy=RejectionSamplingStrategy(loop_budget=3), + return_sampling_results=True, ) -# print result -print(f"***** email ****\n{w(email_v1)}\n*******") +if email_v2_samples.success: + print(f"Success: \n{w(email_v2_samples.result)}") + print( + f"===> Requirement for this sample: \n{req_print(email_v2_samples.sample_validations[-1])}" + ) +else: + print(f"Failure: \n{w(email_v2_samples.result)}") + selected_index = email_v2_samples.sample_generations.index(email_v2_samples.result) + print( + f"===> Requirement for this sample: \n{req_print(email_v2_samples.sample_validations[selected_index])}" + ) # # [optional] get logs for all loops: # from mellea.stdlib.base import GenerateLog diff --git a/docs/examples/mify/rich_document_advanced.py b/docs/examples/mify/rich_document_advanced.py index 186ddd90..9ee0caed 100644 --- a/docs/examples/mify/rich_document_advanced.py +++ b/docs/examples/mify/rich_document_advanced.py @@ -9,7 +9,7 @@ from mellea.stdlib.base import ModelOutputThunk, TemplateRepresentation # Use a `SimpleContext` so that each LLM call is independent. -m = mellea.start_session(backend_name="hf", ctx=mellea.SimpleContext()) +m = mellea.start_session(backend_name="hf") # 2. Let's import docling so that we can process pdf documents. diff --git a/docs/examples/notebooks/context_example.ipynb b/docs/examples/notebooks/context_example.ipynb index 38af6c5d..7392df6f 100644 --- a/docs/examples/notebooks/context_example.ipynb +++ b/docs/examples/notebooks/context_example.ipynb @@ -71,7 +71,7 @@ "source": [ "## Import Mellea and Start a Session with LinearContext\n", "\n", - "Up to this point we have used SimpleContext, a context manager that resets the chat message history on each model call. That is, the model's context is entirely determined by the current Component. \n", + "Up to this point we have used SimpleContext, a context manager that resets the chat message history on each model call. That is, the model's context is entirely determined by the current Component.\n", "\n", "Mellea also provides a LinearContext, which behaves like a chat history. 
We will use the LinearContext to interact with cat hmodels:" ] @@ -84,9 +84,10 @@ }, "outputs": [], "source": [ - "from mellea import LinearContext, start_session\n", + "from mellea import start_session\n", + "from mellea.stdlib.base import ChatContext\n", "\n", - "m = start_session(ctx=LinearContext())\n", + "m = start_session(ctx=ChatContext())\n", "m.chat(\"Make up a math problem.\")\n", "m.chat(\"Solve your math problem.\")\n", "print(m.ctx.last_output())\n", diff --git a/docs/examples/notebooks/m_serve_example.ipynb b/docs/examples/notebooks/m_serve_example.ipynb index 00518460..7fa6e6b0 100644 --- a/docs/examples/notebooks/m_serve_example.ipynb +++ b/docs/examples/notebooks/m_serve_example.ipynb @@ -83,11 +83,11 @@ "\n", "import mellea\n", "from cli.serve.models import ChatMessage\n", - "from mellea.stdlib.base import LinearContext, ModelOutputThunk\n", + "from mellea.stdlib.base import ChatContext, ModelOutputThunk\n", "from mellea.stdlib.requirement import Requirement, simple_validate\n", "from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult\n", "\n", - "session = mellea.start_session(ctx=LinearContext())\n", + "session = mellea.start_session(ctx=ChatContext())\n", "\n", "\n", "def validate_hi_bob(email: str) -> bool:\n", diff --git a/docs/examples/safety.py/guardian.py b/docs/examples/safety.py/guardian.py index 9293ce63..989b0777 100644 --- a/docs/examples/safety.py/guardian.py +++ b/docs/examples/safety.py/guardian.py @@ -2,9 +2,8 @@ from mellea import MelleaSession from mellea.backends import model_ids -from mellea.backends.dummy import DummyBackend from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import Context, ContextTurn, ModelOutputThunk, SimpleContext +from mellea.stdlib.base import ContextTurn, ModelOutputThunk from mellea.stdlib.chat import Message from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk @@ -25,10 +24,9 @@ print("\n Test 2\n") # create a mean conversation and add to context -m.ctx.insert_turn( - ContextTurn(Message("user", "Hello. "), ModelOutputThunk("You are very ugly.")) +m.ctx = m.ctx.add(Message("user", "Hello. ")).add( + ModelOutputThunk("You are very ugly.") ) - # show last turn in chat print(f"Context: {m.ctx.last_turn()}") diff --git a/docs/examples/sessions/creating_a_new_type_of_session.py b/docs/examples/sessions/creating_a_new_type_of_session.py index 3cc34245..14dd98ed 100644 --- a/docs/examples/sessions/creating_a_new_type_of_session.py +++ b/docs/examples/sessions/creating_a_new_type_of_session.py @@ -3,7 +3,7 @@ from mellea import MelleaSession from mellea.backends import Backend, BaseModelSubclass from mellea.backends.ollama import OllamaModelBackend -from mellea.stdlib.base import CBlock, Context, LinearContext, ModelOutputThunk +from mellea.stdlib.base import CBlock, ChatContext, Context, ModelOutputThunk from mellea.stdlib.chat import Message from mellea.stdlib.requirement import Requirement, reqify from mellea.stdlib.safety.guardian import GuardianCheck, GuardianRisk @@ -66,7 +66,7 @@ def chat( m = ChatCheckingSession( requirements=[GuardianCheck("jailbreak"), GuardianCheck("profanity")], backend=OllamaModelBackend(), - ctx=LinearContext(), + ctx=ChatContext(), ) # You can run this code to see the immediate checks working. 
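The change that repeats across these examples is that contexts are now immutable: `add` returns a new context, so the result has to be assigned back to `m.ctx` instead of mutating it with `insert`/`insert_turn`. A minimal sketch of the new pattern, assuming a locally running Ollama server for the default `start_session()` backend and using only APIs that appear elsewhere in this diff:

```python
from mellea import start_session
from mellea.stdlib.base import ChatContext, ModelOutputThunk
from mellea.stdlib.chat import Message

# ChatContext replaces LinearContext. add() does not mutate the context; it
# returns a new one, so the result must be assigned back to the session.
m = start_session(ctx=ChatContext())
m.ctx = m.ctx.add(Message("user", "Hello.")).add(ModelOutputThunk("Hi there!"))

print(m.ctx.last_turn())    # ContextTurn(user message, output thunk)
print(m.ctx.last_output())  # the ModelOutputThunk added above
```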
diff --git a/docs/examples/tutorial/context_example.py b/docs/examples/tutorial/context_example.py index f29e661d..e98e1182 100644 --- a/docs/examples/tutorial/context_example.py +++ b/docs/examples/tutorial/context_example.py @@ -1,9 +1,10 @@ -from mellea import LinearContext, start_session +from mellea import start_session +from mellea.stdlib.base import ChatContext -m = start_session(ctx=LinearContext()) +m = start_session(ctx=ChatContext()) m.chat("Make up a math problem.") m.chat("Solve your math problem.") print(m.ctx.last_output()) - +print("==================") print(m.ctx.last_turn()) diff --git a/docs/examples/tutorial/m_serve_example.py b/docs/examples/tutorial/m_serve_example.py index 66005dd3..79bf6e68 100644 --- a/docs/examples/tutorial/m_serve_example.py +++ b/docs/examples/tutorial/m_serve_example.py @@ -4,11 +4,11 @@ import mellea from cli.serve.models import ChatMessage -from mellea.stdlib.base import LinearContext, ModelOutputThunk +from mellea.stdlib.base import ChatContext, ModelOutputThunk from mellea.stdlib.requirement import Requirement, simple_validate from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult -session = mellea.start_session(ctx=LinearContext()) +session = mellea.start_session(ctx=ChatContext()) def validate_hi_bob(email: str) -> bool: diff --git a/docs/examples/tutorial/simple_email.py b/docs/examples/tutorial/simple_email.py index 8acd9713..45f5bb11 100644 --- a/docs/examples/tutorial/simple_email.py +++ b/docs/examples/tutorial/simple_email.py @@ -26,7 +26,6 @@ def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str: ) ) - print("Email with requirements:") @@ -52,7 +51,6 @@ def write_email_with_requirements( ) ) - print("Email with rejection sampling:") from mellea.stdlib.sampling import RejectionSamplingStrategy # noqa: E402 diff --git a/docs/tutorial.md b/docs/tutorial.md index 6ff7f924..7341eae6 100644 --- a/docs/tutorial.md +++ b/docs/tutorial.md @@ -841,7 +841,7 @@ In the above arguments, `path_or_model_id` refers to the model checkpoint from l We are now ready to create a M session, define the requirement, and run the instruction: ```python -m = MelleaSession(backend, ctx=LinearContext()) +m = MelleaSession(backend, ctx=ChatContext()) failure_check = req("The failure mode should not be none.") res = m.instruct("Write triage summaries based on technician note.", requirements=[failure_check]) ``` @@ -918,13 +918,13 @@ m = mellea.MelleaSession( ) ``` -The `SimpleContext` -- which is the only context we have used so far -- is a context manager that resets the chat message history on each model call. That is, the model's context is entirely determined by the current Component. Mellea also provides a `LinearContext`, which behaves like a chat history. We can use the LinearContext to interact with chat models: +The `SimpleContext` -- which is the only context we have used so far -- is a context manager that resets the chat message history on each model call. That is, the model's context is entirely determined by the current Component. Mellea also provides a `ChatContext`, which behaves like a chat history. 
We can use the ChatContext to interact with chat models: ```python # file: https://github.com/generative-computing/mellea/blob/main/docs/examples/tutorial/context_example.py#L1-L5 from mellea import start_session -m = mellea.start_session(ctx=LinearContext()) +m = mellea.start_session(ctx=ChatContext()) m.chat("Make up a math problem.") m.chat("Solve your math problem.") ``` diff --git a/mellea/__init__.py b/mellea/__init__.py index bbd90e81..a8fc24fa 100644 --- a/mellea/__init__.py +++ b/mellea/__init__.py @@ -1,28 +1,7 @@ """Mellea is a library for building robust LLM applications.""" import mellea.backends.model_ids as model_ids -from mellea.stdlib.base import LinearContext, SimpleContext from mellea.stdlib.genslot import generative -from mellea.stdlib.session import ( - MelleaSession, - chat, - instruct, - query, - start_session, - transform, - validate, -) +from mellea.stdlib.session import MelleaSession, start_session -__all__ = [ - "LinearContext", - "MelleaSession", - "SimpleContext", - "chat", - "generative", - "instruct", - "model_ids", - "query", - "start_session", - "transform", - "validate", -] +__all__ = ["MelleaSession", "generative", "model_ids", "start_session"] diff --git a/mellea/backends/__init__.py b/mellea/backends/__init__.py index 5711aa73..fd76bc50 100644 --- a/mellea/backends/__init__.py +++ b/mellea/backends/__init__.py @@ -42,7 +42,7 @@ def generate_from_context( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: # i.e., ContextDiff + ) -> tuple[ModelOutputThunk, Context]: """Generates a model output from a context. May not mutate the context. This must be called from a running event loop as it creates a task to run the generation request. Args: @@ -50,8 +50,10 @@ def generate_from_context( ctx: The rest of the context. format: A response format to used for structured outputs / constrained decoding. model_options: Any model options to upsert into the defaults for this call. - generate_logs: a `GenerateLog` instance to add log information to. tool_calls: If `True`, then tool calls are extracts from the `action` `Component`. Assumption: if tool_calls is enabled, then the action `Component` has a TemplateRepresentation + + Returns: + a tuple of (ModelOutputThunk, Context) where the Context is the new context after the generation has been completed. """ ... diff --git a/mellea/backends/dummy.py b/mellea/backends/dummy.py index c2216c6e..bde21d8b 100644 --- a/mellea/backends/dummy.py +++ b/mellea/backends/dummy.py @@ -1,7 +1,7 @@ """This module holds shim backends used for smoke tests.""" from mellea.backends import Backend, BaseModelSubclass -from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk +from mellea.stdlib.base import CBlock, Component, Context, ModelOutputThunk class DummyBackend(Backend): @@ -24,15 +24,16 @@ def generate_from_context( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: + ) -> tuple[ModelOutputThunk, Context]: """See constructor for an exmplanation of how DummyBackends work.""" assert format is None, "The DummyBackend does not support constrained decoding." 
if self.responses is None: - return ModelOutputThunk(value="dummy") + mot = ModelOutputThunk(value="dummy") + return mot, ctx.add(action).add(mot) elif self.idx < len(self.responses): return_value = ModelOutputThunk(value=self.responses[self.idx]) self.idx += 1 - return return_value + return return_value, ctx.add(action).add(return_value) else: raise Exception( f"DummyBackend expected no more than {len(self.responses)} calls." diff --git a/mellea/backends/formatter.py b/mellea/backends/formatter.py index 0b296eff..0a83163c 100644 --- a/mellea/backends/formatter.py +++ b/mellea/backends/formatter.py @@ -17,14 +17,10 @@ from mellea.stdlib.base import ( CBlock, Component, - Context, - LinearContext, ModelOutputThunk, - SimpleContext, TemplateRepresentation, ) from mellea.stdlib.chat import Message, ToolMessage -from mellea.stdlib.mobject import Query, Transform class Formatter(abc.ABC): @@ -35,11 +31,6 @@ def print(self, c: Component | CBlock) -> str: """Renders a component for input to a model.""" ... - @abc.abstractmethod - def print_context(self, ctx: Context) -> str: - """Renders a Context for input to a model.""" - ... - @abc.abstractmethod def parse( self, source_component: Component | CBlock, result: ModelOutputThunk @@ -170,23 +161,6 @@ def _parse( else: return result - def print_context(self, ctx: Context) -> str: - """Renders a Context for input to a model.""" - assert not ctx.is_chat_context, ( - "Chat contexts should be handled in a backend by first using `Formatter.to_chat_messages` and then passing the dict to an API endpoint or using hf.apply_chat_template." - ) - match ctx: - case LinearContext(): - linearized_ctx = ctx.render_for_generation() - assert linearized_ctx is not None - return "".join([self.print(x) for x in linearized_ctx]) - case SimpleContext(): - raise Exception("Do not know how to handle a SimpleContext yet.") - case _: - raise Exception( - f"TemplateFormatter does not know how to print a {ctx.__class__.__name__} context." - ) - def _stringify( self, c: ( diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py index ef50eb40..ae8bb249 100644 --- a/mellea/backends/huggingface.py +++ b/mellea/backends/huggingface.py @@ -191,7 +191,7 @@ def generate_from_context( format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, tool_calls: bool = False, - ) -> ModelOutputThunk: + ): """Generate using the huggingface model.""" # Upsert model options. model_opts = self._simplify_and_merge(model_options) @@ -208,12 +208,14 @@ def generate_from_context( if issubclass(type(action), ALoraRequirement): reroute_to_alora = True if reroute_to_alora: - return self._generate_from_context_alora( + mot = self._generate_from_context_alora( action, ctx, format=format, model_options=model_opts ) - return self._generate_from_context_standard( + return mot, ctx.add(mot) + mot = self._generate_from_context_standard( action, ctx, format=format, model_options=model_opts, tool_calls=tool_calls ) + return mot, ctx.add(action).add(mot) def _generate_from_context_alora( self, @@ -236,7 +238,7 @@ def _generate_from_context_alora( "This code block should not execute unless there is a 'constraint' alora loaded." ) # Construct the linearized context. This is very similar to normal generation. 
- linearized_ctx = ctx.render_for_generation() + linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None and len(linearized_ctx) > 1 msgs = self.formatter.to_chat_messages(linearized_ctx) user_message, assistant_message = msgs[-2].content, msgs[-1].content @@ -275,7 +277,7 @@ def _generate_from_context_standard( # If the Context is a ChatHistory then we will pretty-print each content as a message and then use apply_chat_template. # Otherwise, we will linearize the context and treat it as a raw input. if ctx.is_chat_context: - linearized_ctx = ctx.render_for_generation() + linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None, ( "If ctx.is_chat_context, then the context should be linearizable." ) diff --git a/mellea/backends/litellm.py b/mellea/backends/litellm.py index 23ea446e..93a1ef5d 100644 --- a/mellea/backends/litellm.py +++ b/mellea/backends/litellm.py @@ -116,13 +116,14 @@ def generate_from_context( assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." ) - return self._generate_from_chat_context_standard( + mot = self._generate_from_chat_context_standard( action, ctx, format=format, model_options=model_options, tool_calls=tool_calls, ) + return mot, ctx.add(action).add(mot) def _simplify_and_merge( self, model_options: dict[str, Any] | None @@ -216,7 +217,7 @@ def _generate_from_chat_context_standard( tool_calls: bool = False, ) -> ModelOutputThunk: model_opts = self._simplify_and_merge(model_options) - linearized_context = ctx.render_for_generation() + linearized_context = ctx.view_for_generation() assert linearized_context is not None, ( "Cannot generate from a non-linear context in a FormatterBackend." ) diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index a4fe1324..9c16d0d6 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -28,7 +28,6 @@ GenerateType, ModelOutputThunk, ModelToolCall, - TemplateRepresentation, ) from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement @@ -67,8 +66,9 @@ def __init__( self._get_ollama_model_id() # Setup the client and ensure that we have the model available. + self._base_url = base_url self._client = ollama.Client(base_url) - self._async_client = ollama.AsyncClient(base_url) + if not self._check_ollama_server(): err = f"could not create OllamaModelBackend: ollama server not running at {base_url}" FancyLogger.get_logger().error(err) @@ -242,7 +242,7 @@ def generate_from_context( assert ctx.is_chat_context, ( "The ollama backend only supports chat-like contexts." ) - return self.generate_from_chat_context( + mot = self.generate_from_chat_context( action, ctx, format=format, @@ -250,6 +250,8 @@ def generate_from_context( tool_calls=tool_calls, ) + return mot, ctx.add(action).add(mot) + def generate_from_chat_context( self, action: Component | CBlock, @@ -271,7 +273,7 @@ def generate_from_chat_context( """ model_opts = self._simplify_and_merge(model_options) - linearized_context = ctx.render_for_generation() + linearized_context = ctx.view_for_generation() assert linearized_context is not None, ( "Cannot generate from a non-linear context in a FormatterBackend." ) @@ -316,10 +318,13 @@ def generate_from_chat_context( add_tools_from_context_actions(tools, [action]) FancyLogger.get_logger().info(f"Tools for call: {tools.keys()}") + # Ollama ties its async client to an event loop so we have to create it here. 
+ async_client = ollama.AsyncClient(self._base_url) + # Generate a chat response from ollama, using the chat messages. Can be either type since stream is passed as a model option. chat_response: Coroutine[ Any, Any, AsyncIterator[ollama.ChatResponse] | ollama.ChatResponse - ] = self._async_client.chat( + ] = async_client.chat( model=self._get_ollama_model_id(), messages=conversation, tools=list(tools.values()), @@ -381,10 +386,11 @@ async def get_response(coroutines): responses = await asyncio.gather(*coroutines, return_exceptions=True) return responses + async_client = ollama.AsyncClient(self._base_url) # Run async so that we can make use of Ollama's concurrency. coroutines = [] for prompt in prompts: - co = self._async_client.generate( + co = async_client.generate( model=self._get_ollama_model_id(), prompt=prompt, raw=True, diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py index 11646ad9..2164eafd 100644 --- a/mellea/backends/openai.py +++ b/mellea/backends/openai.py @@ -42,7 +42,6 @@ GenerateLog, GenerateType, ModelOutputThunk, - ModelToolCall, ) from mellea.stdlib.chat import Message from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement @@ -277,13 +276,14 @@ def generate_from_context( assert ctx.is_chat_context, NotImplementedError( "The Openai backend only supports chat-like contexts." ) - return self.generate_from_chat_context( + mot = self.generate_from_chat_context( action, ctx, format=format, model_options=model_options, tool_calls=tool_calls, ) + return mot, ctx.add(action).add(mot) def generate_from_chat_context( self, @@ -342,7 +342,7 @@ def _generate_from_chat_context_alora( ) # Construct the linearized context. This is very similar to normal generation. - linearized_ctx = ctx.render_for_generation() + linearized_ctx = ctx.view_for_generation() assert linearized_ctx is not None and len(linearized_ctx) > 1 msgs = self.formatter.to_chat_messages(linearized_ctx) user_message, assistant_message = msgs[-2].content, msgs[-1].content @@ -417,7 +417,7 @@ def _generate_from_chat_context_standard( model_opts = self._simplify_and_merge( model_options, is_chat_context=ctx.is_chat_context ) - linearized_context = ctx.render_for_generation() + linearized_context = ctx.view_for_generation() assert linearized_context is not None, ( "Cannot generate from a non-linear context in a FormatterBackend." ) diff --git a/mellea/backends/watsonx.py b/mellea/backends/watsonx.py index e9400748..155773aa 100644 --- a/mellea/backends/watsonx.py +++ b/mellea/backends/watsonx.py @@ -92,14 +92,14 @@ def __init__( if project_id is None: project_id = os.environ.get("WATSONX_PROJECT_ID") - _creds = Credentials(url=base_url, api_key=api_key) - _client = APIClient(credentials=_creds) - self._model = ModelInference( + self._creds = Credentials(url=base_url, api_key=api_key) + _client = APIClient(credentials=self._creds) + self._model_inference = ModelInference( model_id=self._get_watsonx_model_id(), api_client=_client, - credentials=_creds, + credentials=self._creds, project_id=project_id, - params=model_options, + params=self.model_options, **kwargs, ) @@ -132,6 +132,12 @@ def __init__( ModelOption.MAX_NEW_TOKENS: "max_new_tokens", } + @property + def _model(self) -> ModelInference: + """Watsonx's client gets tied to a specific event loop. 
Reset it here.""" + self._model_inference.set_api_client(APIClient(self._creds)) + return self._model_inference + def _get_watsonx_model_id(self) -> str: """Gets the watsonx model id from the model_id that was provided in the constructor. Raises AssertionError if the ModelIdentifier does not provide a watsonx_name.""" watsonx_model_id = ( @@ -218,13 +224,14 @@ def generate_from_context( assert ctx.is_chat_context, NotImplementedError( "The watsonx.ai backend only supports chat-like contexts." ) - return self.generate_from_chat_context( + mot = self.generate_from_chat_context( action, ctx, format=format, model_options=model_options, tool_calls=tool_calls, ) + return mot, ctx.add(action).add(mot) def generate_from_chat_context( self, @@ -241,7 +248,7 @@ def generate_from_chat_context( model_options, is_chat_context=ctx.is_chat_context ) - linearized_context = ctx.render_for_generation() + linearized_context = ctx.view_for_generation() assert linearized_context is not None, ( "Cannot generate from a non-linear context in a FormatterBackend." ) diff --git a/mellea/helpers/async_helpers.py b/mellea/helpers/async_helpers.py index d3ed2744..0d3e1866 100644 --- a/mellea/helpers/async_helpers.py +++ b/mellea/helpers/async_helpers.py @@ -2,6 +2,8 @@ from collections.abc import AsyncIterator, Coroutine from typing import Any +from mellea.stdlib.base import ModelOutputThunk + async def send_to_queue( co: Coroutine[Any, Any, AsyncIterator | Any] | AsyncIterator, aqueue: asyncio.Queue @@ -29,3 +31,15 @@ async def send_to_queue( # them to the queue. except Exception as e: await aqueue.put(e) + + +async def wait_for_all_mots(mots: list[ModelOutputThunk]): + """Helper function to make waiting for multiple ModelOutputThunks to be computed easier. + + All ModelOutputThunks must be from the same event loop. This should always be the case in sampling + functions, session functions, and top-level mellea functions.""" + coroutines: list[Coroutine[Any, Any, str]] = [] + for mot in mots: + coroutines.append(mot.avalue()) + + await asyncio.gather(*coroutines) diff --git a/mellea/stdlib/base.py b/mellea/stdlib/base.py index d7ad307c..dc09c23c 100644 --- a/mellea/stdlib/base.py +++ b/mellea/stdlib/base.py @@ -12,7 +12,7 @@ from copy import copy, deepcopy from dataclasses import dataclass from io import BytesIO -from typing import Any, Protocol, runtime_checkable +from typing import Any, Protocol, TypeVar, runtime_checkable from PIL import Image as PILImage @@ -344,203 +344,138 @@ class ContextTurn: output: ModelOutputThunk | None -class Context(abc.ABC): - """A `Context` is used to track the state of a `MelleaSession`.""" +ContextT = TypeVar("ContextT", bound="Context") - is_chat_context: bool = False - @abc.abstractmethod - def reset(self): - """Resets the context to a fresh state. +class Context(abc.ABC): + """A `Context` is used to track the state of a `MelleaSession`. - Note: resetting a context does NOT free memory or clear cache. For this reason, you probably want to be calling this method from a `Session`. - """ - ... + A context is immutable. Every alteration leads to a new context. + """ - @abc.abstractmethod - def insert( - self, - value: CBlock | Component, - *, - key: Any | None = None, - generate_logs: list[GenerateLog] | None = None, - ): - """Each `Context` must define its own semantics for inserting something into the context. + _previous: Context | None + _data: Component | CBlock | None + _is_root: bool + _is_chat_context: bool = True - Args: - value (CBlock | Component): the thing to insert. 
- key (Optional[Any]): a key by which the value is indexed to. This is optional and only needed for fairly sophisticated Context types. Note that this is NOT necessarily a key that can be used for KV cache lookups! - generate_logs: Adding log information about the insertion. Should only be used for output objects. - """ - ... + def __init__(self): + """Constructs a new root context with no content.""" + self._previous = None + self._data = None + self._is_root = True - @abc.abstractmethod - def insert_turn( - self, turn: ContextTurn, *, generate_logs: list[GenerateLog] | None = None - ): - """Insert a turn into the chat history. + # factory functions below this line. - Args: - turn: the turn to insert. - generate_logs: Adding log information about the insertion. Will be bound to the output part of the turn. + @classmethod + def from_previous( + cls: type[ContextT], previous: Context, data: Component | CBlock + ) -> ContextT: + """Constructs a new context from an existing context.""" - Returns: - None - """ - ... + assert isinstance(previous, Context), ( + "Cannot create a new context from a non-Context object." + ) + assert data is not None, "Cannot create a new context from None data." - @abc.abstractmethod - def copy(self) -> Context: - """Produces a copy of the current Context's contents, allowing for branch-and-merge style semantics over a Context. + x = cls() + x._previous = previous + x._data = data + x._is_root = False + x._is_chat_context = previous._is_chat_context + return x - Implementations should not copy the actual objects in the context but retain a reference to them.""" - ... + @classmethod + def reset_to_new(cls: type[ContextT]) -> ContextT: + """Returns an empty context for convenience.""" + return cls() - @abc.abstractmethod - def _hash_for_kv_cache(self): - """A `Context` is responsible for maintaining a hash representation of itself. This hash is used by backends to refer to a Context's state.""" - ... + # Internal functions below this line. - @abc.abstractmethod - def render_for_generation(self) -> list[Component | CBlock] | None: - """Provides a linear list of context components to use for generation, or None if that is not possible to construct.""" - ... + @property + def is_root_node(self) -> bool: + """Returns whether this context is the root context node.""" + return self._is_root - @abc.abstractmethod - def actions_for_available_tools(self) -> list[Component | CBlock] | None: - """Provides a list of actions to extract tools from for use with during generation, or None if that's not possible. + @property + def previous_node(self) -> Context | None: + """Returns the context node from which this context node was created. - Can be used to make the available tools differ from the tools of all the actions in the context. + Internal use: Users should not need to use this property. """ - ... + return self._previous - @abc.abstractmethod - def full_event_log(self) -> list[Component | CBlock]: - """Provides a list of all events stored in the context.""" - ... - - @abc.abstractmethod - def last_output(self) -> ModelOutputThunk | None: - """The last output thunk of the context.""" - ... + @property + def node_data(self) -> Component | CBlock | None: + """Returns the data associated with this context node. - @abc.abstractmethod - def last_turn(self) -> ContextTurn | None: - """The last input/output turn of the context.""" - ... + Internal use: Users should not need to use this property. 
+ """ + return self._data @property - @abc.abstractmethod - def logs(self) -> list[list[GenerateLog] | None]: - """Returns a list of all logs in the context.""" - ... + def is_chat_context(self) -> bool: + """Returns whether this context is a chat context.""" + return self._is_chat_context - @abc.abstractmethod - def get_logs_by_index(self, index: int) -> list[GenerateLog] | None: - """Returns a `GenerateLog` for the given index.""" - ... - - @abc.abstractmethod - def last_output_and_logs( - self, all_intermediate_results: bool = False - ) -> tuple[ModelOutputThunk | None, list[GenerateLog] | None | GenerateLog]: - """Returns a `ModelOutputThunk` for the last output and the corresponding `GenerateLog`. + # User functions below this line. - Args: - all_intermediate_results: if False (default), only returns the Log for the that led to the final output, if True, a list of all intermediate results (including the final one) is returned. - """ - ... + def as_list(self, last_n_components: int | None = None) -> list[Component | CBlock]: + """Returns a list of the last n components in the context sorted from FIRST TO LAST. + If `last_n_components` is `None`, then all components are returned.""" + context_list: list[Component | CBlock] = [] + current_context: Context = self -class BasicContext(Context, abc.ABC): - """Implementing some common functionality for Contexts.""" + last_n_count = 0 + while not current_context.is_root_node and ( + last_n_components is None or last_n_count < last_n_components + ): + data = current_context.node_data + assert data is not None, "Data cannot be None (except for root context)." + assert data not in context_list, ( + "There might be a cycle in the context tree. That is not allowed." + ) + context_list.append(data) + last_n_count += 1 - _ctx: list[CBlock | Component | ModelOutputThunk] = [] - _log_ctx: list[list[GenerateLog] | None] = [] + current_context = current_context.previous_node # type: ignore + assert current_context is not None, ( + "Previous context cannot be None (except for root context)." + ) - def __init__(self): - """Constructs a basic context.""" - self._ctx = [] - self._log_ctx = [] + context_list.reverse() + return context_list def actions_for_available_tools(self) -> list[Component | CBlock] | None: """Provides a list of actions to extract tools from for use with during generation, or None if that's not possible. - Can be used to make the available tools differ from the tools of all the actions in the context. - In most cases, this will just be the same context as `render_for_generation`. + Can be used to make the available tools differ from the tools of all the actions in the context. Can be overwritten by subclasses. """ - return self.render_for_generation() + return self.view_for_generation() - def last_output(self): + def last_output(self, check_last_n_components: int = 3) -> ModelOutputThunk | None: """The last output thunk of the context.""" - for c in self._ctx[::-1]: + + for c in self.as_list(last_n_components=check_last_n_components)[::-1]: if isinstance(c, ModelOutputThunk): return c return None - @property - def logs(self) -> list[list[GenerateLog] | None]: - """Returns a list of all logs in the context.""" - return list(self._log_ctx) + def last_turn(self): + """The last input/output turn of the context. 
- def get_logs_by_index(self, index: int) -> list[GenerateLog] | None: - """Returns the log of a given index from the context.""" - try: - return self._log_ctx[index] - except IndexError: - FancyLogger.get_logger().warn(f"Index {index} for logs is out of range") - return None + This can be partial. If the last event is an input, then the output is None. + """ - def last_output_and_logs( - self, all_intermediate_results: bool = False - ) -> tuple[ModelOutputThunk | None, list[GenerateLog] | GenerateLog | None]: - """The last output thunk of the context and the corresponding log.""" - last: ModelOutputThunk | None = None - last_i = 0 - for c in self._ctx[::-1]: - last_i -= 1 - if isinstance(c, ModelOutputThunk): - last = c - break - if last is None: - return None, None - else: - logs = self.get_logs_by_index(last_i) - if all_intermediate_results or logs is None: - # return everything - return last, logs - else: - log = [log for log in logs if log.is_final_result] - # if there is only one log in history, this should be the one. - if len(log) == 0: - if len(logs) == 1: - FancyLogger.get_logger().warn( - f"No final result found for log {logs[0]}. Returning the only result." - ) - log = logs - else: - FancyLogger.get_logger().warn( - f"No final result found for log {logs[0]}. Could not decide which log to return. Returning None." - ) - return last, None - assert len(log) == 1, ( - f"Found multiple/none final results for logs: {len(log)}, " - ) - return last, log[0] - - def full_event_log(self) -> list[Component | CBlock]: - """Returns the underlying _ctx.""" - return self._ctx + history = self.as_list(last_n_components=2) - def last_turn(self): - """The last input/output turn of the context.""" - if len(self._ctx) == 0: + if len(history) == 0: return None - last_element = self._ctx[-1] + last_element = history[-1] if isinstance(last_element, ModelOutputThunk): - if len(self._ctx) >= 2: + if len(history) >= 2: # assuming that the last two elements are input and output - return ContextTurn(self._ctx[-2], last_element) + return ContextTurn(history[-2], last_element) else: # if self._ctx is of size 1 and only element is output element, return partial turn without an input. return ContextTurn(None, last_element) @@ -548,141 +483,46 @@ def last_turn(self): # if the last element is input element, return partial turn without output return ContextTurn(last_element, None) - def __str__(self): - """Pretty prints the context. For debugging.""" - return f"{self.__class__.__name__} := \n" + "\n".join( - [f" {c!s}" for c in self._ctx] - ) + # Abstract methods below this line. - def copy(self): - """Copies all attributes of the Context. `_ctx` and `_log_ctx` are shallow copies. + @abc.abstractmethod + def add(self, c: Component | CBlock) -> Context: + """Returns a new context obtained by adding `c` to this context.""" + # something along ....from_previous(self, c) + ... - This means that the lists are different (you can independently insert to the new/old context), but that the objects in the old/new lists are the same at copy time. - """ - new = copy(self) - new._ctx = copy(self._ctx) - new._log_ctx = copy(self._log_ctx) - return new + @abc.abstractmethod + def view_for_generation(self) -> list[Component | CBlock] | None: + """Provides a linear list of context components to use for generation, or None if that is not possible to construct.""" + ... 
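To make the new linked-list design concrete: `from_previous` builds a node whose `_previous` pointer is the old context, and `as_list` walks those pointers back to the root and returns the items oldest first. The toy subclass below (illustrative only, not part of this change) shows the minimum a concrete `Context` has to implement:

```python
from mellea.stdlib.base import CBlock, Component, Context


class LastTwoContext(Context):
    """Toy context that exposes only the two most recent items to generation."""

    def add(self, c: Component | CBlock) -> "LastTwoContext":
        # from_previous() returns a new node pointing back at `self`.
        return LastTwoContext.from_previous(self, c)

    def view_for_generation(self) -> list[Component | CBlock] | None:
        # as_list(n) walks the _previous chain and returns the last n items,
        # ordered oldest to newest.
        return self.as_list(2)


ctx = LastTwoContext().add(CBlock("a")).add(CBlock("b")).add(CBlock("c"))
print(ctx.as_list())              # all three CBlocks, oldest first
print(ctx.view_for_generation())  # only the last two
```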
-class LinearContext(BasicContext): - """Initializes a linear context with unbounded window_size and is_chat=True by default.""" +class ChatContext(Context): + """Initializes a chat context with unbounded window_size and is_chat=True by default.""" - def __init__( - self, - *, - window_size: int | None = None, - log_window_size: int | None = 10, - is_chat_context=True, - ): - """Initializes a linear context with unbounded window_size (log_window_size = 10) and is_chat=True by default.""" + def __init__(self, *, window_size: int | None = None): + """Constructs a new chat context.""" super().__init__() - self.window_size = window_size - self._log_window_size = log_window_size - self.is_chat_context = is_chat_context - - def reset(self): - """Resets the context to a fresh state. - - Note: resetting a context does NOT free memory or clear cache. For this reason, you probably want to be calling this method from a `Session`. - """ - self._ctx = [] - - def insert( - self, - value: CBlock | Component, - *, - key: Any | None = None, - generate_logs: list[GenerateLog] | None = None, - ): - """Inserts into the context and then shifts the window forward if necessary.""" - self._ctx.append(value) - self._log_ctx.append(generate_logs) - if self.window_size is not None and len(self._ctx) > self.window_size: - del self._ctx[0] - if ( - self._log_window_size is not None - and len(self._log_ctx) > self._log_window_size - ): - del self._log_ctx[0] - - def insert_turn( - self, turn: ContextTurn, *, generate_logs: list[GenerateLog] | None = None - ): - """Insert a turn into the context.""" - if turn.model_input: - self.insert(turn.model_input, generate_logs=None) - if turn.output: - self.insert(turn.output, generate_logs=generate_logs) - - def render_for_generation(self) -> list[Component | CBlock] | None: - """Returns the underlying _ctx list for generation.""" - return self._ctx - - def is_chat_history(self): - FancyLogger.get_logger().warning( - "is_chat_history doesn't work properly, because ModelOutputThunks are not Messages." - ) - """Returns true if everything in the LinearContext is a chat `Message`.""" - return all( - str(type(x)) == "Message" for x in self._ctx - ) # sic: avoids circular import. - - def _hash_for_kv_cache(self): - """Constructs a hash that corresponds to the string contents of the KV cache associated with this context.""" - assert False, "not supported yet." + self._window_size = window_size + def add(self, c: Component | CBlock) -> ChatContext: + new = ChatContext.from_previous(self, c) + new._window_size = self._window_size + return new -class SimpleContext(BasicContext): - """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved. + def view_for_generation(self) -> list[Component | CBlock] | None: + return self.as_list(self._window_size) - This context is intended for applications where each LLM call is (mostly) a stand-alone request. Patterns like instruct-validate-repair fall into this category. - It is possible for a single turn to have many different CBlocks/Components. This can happen for a variety of reasons: - 1. Instruct/Repair is actually up to 3 (not 4!) turns: a system, a user, an assistant, and then the ALora output. - 2. It's possible to have a Component with a bunch of other stuff in it. We haven't decided how to represent this in Span world yet, but it's possible that one approach would be to have any causal dependency structure represented in terms of a linearization or poset. 
- """ +class SimpleContext(Context): + """A `SimpleContext` is a context in which each interaction is a separate and independent turn. The history of all previous turns is NOT saved..""" - def __init__(self): - """Initializes a SimpleContext which contains at max one turn. with is_chat_context=True.""" - super().__init__() - self.is_chat_context = True + def add(self, c: Component | CBlock) -> SimpleContext: + return SimpleContext.from_previous(self, c) - def render_for_generation(self) -> list[Component | CBlock] | None: - """Uses _ctx ordering.""" + def view_for_generation(self) -> list[Component | CBlock] | None: return [] - def reset(self): - """Resets the context to a fresh state. - - Note: resetting a context does NOT free memory or clear cache. For this reason, you probably want to be calling this method from a `Session`. - """ - self._ctx = [] - - def insert( - self, - value: CBlock | Component, - *, - key: Any | None = None, - generate_logs: list[GenerateLog] | None = None, - ): - """Adds the value to the context.""" - assert key is None - self._ctx = [value] - self._log_ctx = [generate_logs] - - def insert_turn( - self, turn: ContextTurn, *, generate_logs: list[GenerateLog] | None = None - ): - """Removes the previous turn and starts a new one.""" - self.reset() - self._ctx = [x for x in [turn.model_input, turn.output] if x] - self._log_ctx = [None, generate_logs] - - def _hash_for_kv_cache(self): - """Constructs a hash that corresponds to the string contents of the KV cache associated with this context.""" - assert False, "not supported yet." - @dataclass class TemplateRepresentation: diff --git a/mellea/stdlib/chat.py b/mellea/stdlib/chat.py index 1b4ef38c..084727b9 100644 --- a/mellea/stdlib/chat.py +++ b/mellea/stdlib/chat.py @@ -129,7 +129,7 @@ def _to_msg(c: CBlock | Component | ModelOutputThunk) -> Message | None: case _: return None - all_ctx_events = ctx.full_event_log() + all_ctx_events = ctx.as_list() if all_ctx_events is None: raise Exception("Trying to cast a non-linear history into a chat history.") else: diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py new file mode 100644 index 00000000..cc78b7d4 --- /dev/null +++ b/mellea/stdlib/funcs.py @@ -0,0 +1,647 @@ +"""Functions for Mellea operations like Instruct, Chat, etc...""" + +from __future__ import annotations + +import asyncio +from collections.abc import Coroutine +from concurrent.futures import ThreadPoolExecutor +from typing import Any, Literal, TypeVar, overload + +from PIL import Image as PILImage + +from mellea.backends import Backend, BaseModelSubclass +from mellea.backends.formatter import FormatterBackend +from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import ( + CBlock, + ChatContext, + Component, + Context, + GenerateLog, + ImageBlock, + ModelOutputThunk, + SimpleContext, +) +from mellea.stdlib.chat import Message, ToolMessage +from mellea.stdlib.instruction import Instruction +from mellea.stdlib.mify import mify +from mellea.stdlib.mobject import MObjectProtocol +from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.sampling import ( + RejectionSamplingStrategy, + SamplingResult, + SamplingStrategy, +) + + +@overload +def act( + action: Component, + context: Context, + backend: Backend, + *, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = 
False, +) -> tuple[ModelOutputThunk, Context]: ... + + +@overload +def act( + action: Component, + context: Context, + backend: Backend, + *, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> SamplingResult: ... + + +def act( + action: Component, + context: Context, + backend: Backend, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context] | SamplingResult: + """Runs a generic action, and adds both the action and the result to the context. + + Args: + action: the Component from which to generate. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + requirements: used as additional requirements when a sampling strategy is provided. + strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: if set, the BaseModel to use for constrained decoding. + model_options: additional model options, which will upsert into the model/backend's defaults. + tool_calls: if true, tool calling is enabled. + + Returns: + A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. + """ + + out = _run_async_in_thread( + _act( + action, + context, + backend, + requirements=requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + ) + + return out + + +async def _act( + action: Component, + context: Context, + backend: Backend, + *, + requirements: list[Requirement] | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context] | SamplingResult: + """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. + + Args: + action: the Component from which to generate. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + requirements: used as additional requirements when a sampling strategy is provided + strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: if set, the BaseModel to use for constrained decoding. + model_options: additional model options, which will upsert into the model/backend's defaults. + tool_calls: if true, tool calling is enabled. + + Returns: + A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. 
+ """ + sampling_result: SamplingResult | None = None + generate_logs: list[GenerateLog] = [] + + if return_sampling_results: + assert strategy is not None, ( + "Must provide a SamplingStrategy when return_sampling_results==True" + ) + + # if there is no reason to sample, just generate from the context. + if strategy is None or requirements is None or len(requirements) == 0: + if strategy is None and requirements is not None: + FancyLogger.get_logger().warning( + "Calling the function with NO strategy BUT requirements. No requirement is being checked!" + ) + + result, new_ctx = backend.generate_from_context( + action, + ctx=context, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + await result.avalue() + + # ._generate_log should never be None after generation. + assert result._generate_log is not None + result._generate_log.is_final_result = True + generate_logs.append(result._generate_log) + + else: + # if there is a reason to sample, use the sampling strategy. + + sampling_result = await strategy.sample( + action, + context=context, + backend=backend, + requirements=requirements, + validation_ctx=None, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + + assert sampling_result.sample_generations is not None + for result in sampling_result.sample_generations: + assert result._generate_log is not None # Cannot be None after generation. + generate_logs.append(result._generate_log) + + new_ctx = sampling_result.result_ctx + result = sampling_result.result + assert sampling_result.result._generate_log is not None + assert sampling_result.result._generate_log.is_final_result, ( + "generate logs from the final result returned by the sampling strategy must be marked as final" + ) + + if return_sampling_results: + assert ( + sampling_result is not None + ) # Needed for the type checker but should never happen. + return sampling_result + else: + return result, new_ctx + + +@overload +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1), + return_sampling_results: Literal[False] = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context]: ... + + +@overload +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1), + return_sampling_results: Literal[True], + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> SamplingResult: ... 
+ + +def instruct( + description: str, + context: Context, + backend: Backend, + *, + images: list[ImageBlock] | list[PILImage.Image] | None = None, + requirements: list[Requirement | str] | None = None, + icl_examples: list[str | CBlock] | None = None, + grounding_context: dict[str, str | CBlock | Component] | None = None, + user_variables: dict[str, str] | None = None, + prefix: str | CBlock | None = None, + output_prefix: str | CBlock | None = None, + strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1), + return_sampling_results: bool = False, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context] | SamplingResult: + """Generates from an instruction. + + Args: + description: The description of the instruction. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + requirements: A list of requirements that the instruction can be validated against. + icl_examples: A list of in-context-learning examples that the instruction can be validated against. + grounding_context: A list of grounding contexts that the instruction can use. They can bind as variables using a (key: str, value: str | ContentBlock) tuple. + user_variables: A dict of user-defined variables used to fill in Jinja placeholders in other parameters. This requires that all other provided parameters are provided as strings. + prefix: A prefix string or ContentBlock to use when generating the instruction. + output_prefix: A string or ContentBlock that defines a prefix for the output generation. Usually you do not need this. + strategy: A SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. + return_sampling_results: attach the (successful and failed) sampling attempts to the results. + format: If set, the BaseModel to use for constrained decoding. + model_options: Additional model options, which will upsert into the model/backend's defaults. + tool_calls: If true, tool calling is enabled. + images: A list of images to be used in the instruction or None if none. + """ + + requirements = [] if requirements is None else requirements + icl_examples = [] if icl_examples is None else icl_examples + grounding_context = dict() if grounding_context is None else grounding_context + + images = _parse_and_clean_image_args(images) + + # All instruction options are forwarded to create a new Instruction object. 
+ i = Instruction( + description=description, + requirements=requirements, + icl_examples=icl_examples, + grounding_context=grounding_context, + user_variables=user_variables, + prefix=prefix, + output_prefix=output_prefix, + images=images, + ) + + return act( + i, + context=context, + backend=backend, + requirements=i.requirements, + strategy=strategy, + return_sampling_results=return_sampling_results, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) # type: ignore[call-overload] + + +def chat( + content: str, + context: Context, + backend: Backend, + *, + role: Message.Role = "user", + images: list[ImageBlock] | list[PILImage.Image] | None = None, + user_variables: dict[str, str] | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[Message, Context]: + """Sends a simple chat message and returns the response. Adds both messages to the Context.""" + if user_variables is not None: + content_resolved = Instruction.apply_user_dict_from_jinja( + user_variables, content + ) + else: + content_resolved = content + images = _parse_and_clean_image_args(images) + user_message = Message(role=role, content=content_resolved, images=images) + + result, new_ctx = act( + user_message, + context=context, + backend=backend, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + parsed_assistant_message = result.parsed_repr + assert isinstance(parsed_assistant_message, Message) + + return parsed_assistant_message, new_ctx + + +def validate( + reqs: Requirement | list[Requirement], + context: Context, + backend: Backend, + *, + output: CBlock | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] + | None = None, # TODO: Can we get rid of gen logs here and in act? + input: CBlock | None = None, +) -> list[ValidationResult]: + """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" + # Run everything in the specific event loop for this session. + + out = _run_async_in_thread( + _validate( + reqs=reqs, + context=context, + backend=backend, + output=output, + format=format, + model_options=model_options, + generate_logs=generate_logs, + input=input, + ) + ) + + # Wait for and return the result. + return out + + +async def _validate( + reqs: Requirement | list[Requirement], + context: Context, + backend: Backend, + *, + output: CBlock | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + generate_logs: list[GenerateLog] | None = None, + input: CBlock | None = None, +) -> list[ValidationResult]: + """Asynchronous version of .validate; validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" + # Turn a solitary requirement in to a list of requirements, and then reqify if needed. 
+ reqs = [reqs] if not isinstance(reqs, list) else reqs + reqs = [Requirement(req) if type(req) is str else req for req in reqs] + if output is None: + validation_target_ctx = context + else: + validation_target_ctx = SimpleContext() + + # Add the input/output to the validation context + if input is not None: + validation_target_ctx = validation_target_ctx.add(input) + validation_target_ctx = validation_target_ctx.add(output) + + rvs: list[ValidationResult] = [] + coroutines: list[Coroutine[Any, Any, ValidationResult]] = [] + + for requirement in reqs: + val_result_co = requirement.validate( + backend, validation_target_ctx, format=format, model_options=model_options + ) + coroutines.append(val_result_co) + + for val_result in await asyncio.gather(*coroutines): + rvs.append(val_result) + + # If the validator utilized a backend to generate a result, attach the corresponding + # info to the generate_logs list. + if generate_logs is not None: + if val_result.thunk is not None: + thunk = val_result.thunk + assert ( + thunk._generate_log is not None + ) # Cannot be None after generation. + generate_logs.append(thunk._generate_log) + else: + # We have to append None here so that the logs line-up. + # TODO: A better solution should be found for this edge case. + # This is the only scenario where ValidationResults are supposed to line + # up with GenerateLogs. + generate_logs.append(None) # type: ignore + + return rvs + + +def query( + obj: Any, + query: str, + context: Context, + backend: Backend, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, +) -> tuple[ModelOutputThunk, Context]: + """Query method for retrieving information from an object. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + query: The string representing the query to be executed against the object. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + format: format for output parsing. + model_options: Model options to pass to the backend. + tool_calls: If true, the model may make tool calls. Defaults to False. + + Returns: + ModelOutputThunk: The result of the query as processed by the backend. + """ + if not isinstance(obj, MObjectProtocol): + obj = mify(obj) + + assert isinstance(obj, MObjectProtocol) + q = obj.get_query_object(query) + + answer = act( + q, + context=context, + backend=backend, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + return answer + + +def transform( + obj: Any, + transformation: str, + context: Context, + backend: Backend, + *, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, +) -> tuple[ModelOutputThunk | Any, Context]: + """Transform method for creating a new object with the transformation applied. + + Args: + obj : The object to be queried. It should be an instance of MObject or can be converted to one if necessary. + transformation: The string representing the query to be executed against the object. + context: the context being used as a history from which to generate the response. + backend: the backend used to generate the response. + + Returns: + ModelOutputThunk|Any: The result of the transformation as processed by the backend. If no tools were called, + the return type will be always be ModelOutputThunk. 
If a tool was called, the return type will be the return type + of the function called, usually the type of the object passed in. + """ + if not isinstance(obj, MObjectProtocol): + obj = mify(obj) + + assert isinstance(obj, MObjectProtocol) + t = obj.get_transform_object(transformation) + + # Check that your model / backend supports tool calling. + # This might throw an error when tools are provided but can't be handled by one or the other. + transformed, new_ctx = act( + t, + context=context, + backend=backend, + format=format, + model_options=model_options, + tool_calls=True, + ) + + tools = _call_tools(transformed, backend) + + # Transform only supports calling one tool call since it cannot currently synthesize multiple outputs. + # Attempt to choose the best one to call. + chosen_tool: ToolMessage | None = None + if len(tools) == 1: + # Only one function was called. Choose that one. + chosen_tool = tools[0] + + elif len(tools) > 1: + for output in tools: + if type(output._tool_output) is type(obj): + chosen_tool = output + break + + if chosen_tool is None: + chosen_tool = tools[0] + + FancyLogger.get_logger().warning( + f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" + # type: ignore + ) + + if chosen_tool: + # Tell the user the function they should've called if no generated values were added. + if len(chosen_tool._tool.args.keys()) == 0: + FancyLogger.get_logger().warning( + f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" + ) + + new_ctx.add(chosen_tool) + FancyLogger.get_logger().info( + "added a tool message from transform to the context" + ) + return chosen_tool._tool_output, new_ctx + + return transformed, new_ctx + + +def _parse_and_clean_image_args( + images_: list[ImageBlock] | list[PILImage.Image] | None = None, +) -> list[ImageBlock] | None: + images: list[ImageBlock] | None = None + if images_ is not None: + assert isinstance(images_, list), "Images should be a list or None." + + if len(images_) > 0: + if isinstance(images_[0], PILImage.Image): + images = [ + ImageBlock.from_pil_image(i) + for i in images_ + if isinstance(i, PILImage.Image) + ] + else: + images = images_ # type: ignore + assert isinstance(images, list) + assert all(isinstance(i, ImageBlock) for i in images), ( + "All images should be ImageBlocks now." + ) + else: + images = None + return images + + +def _call_tools(result: ModelOutputThunk, backend: Backend) -> list[ToolMessage]: + """Call all the tools requested in a result's tool calls object. + + Returns: + list[ToolMessage]: A list of tool messages that can be empty. + """ + # There might be multiple tool calls returned. + outputs: list[ToolMessage] = [] + tool_calls = result.tool_calls + if tool_calls: + # Call the tools and decide what to do. + for name, tool in tool_calls.items(): + try: + output = tool.call_func() + except Exception as e: + output = e + + content = str(output) + if isinstance(backend, FormatterBackend): + content = backend.formatter.print(output) # type: ignore + + outputs.append( + ToolMessage( + role="tool", + content=content, + tool_output=output, + name=name, + args=tool.args, + tool=tool, + ) + ) + return outputs + + +R = TypeVar("R") + + +def _run_async_in_thread(co: Coroutine[Any, Any, R]) -> R: + """Runs the provided coroutine. + + Checks if an event loop is running in this thread. 
If one is running in this thread, + we use a separate thread to run the async code. Otherwise, run the code using asyncio.run. + """ + + def run_async(co: Coroutine): + """Helper function to run the coroutine.""" + return asyncio.run(co) + + # Check for a running loop. + loop = None + try: + loop = asyncio.get_running_loop() + except Exception: + pass + + if loop is None: + # We can run it here since there's no currently running event loop. + out = run_async(co) + else: + # We have to run it in a new thread since there's a running event loop. + # Use a ThreadPoolExecutor to more easily extract the result. + with ThreadPoolExecutor(max_workers=1) as exec: + future = exec.submit(run_async, co) + out = future.result() + + return out diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index 0b0b0061..64747e38 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -6,17 +6,17 @@ from copy import copy from typing import Any, overload -from mellea.backends import ( - Backend, - BaseModelSubclass, +from mellea.backends import Backend, BaseModelSubclass +from mellea.backends.aloras import Alora +from mellea.helpers.fancy_logger import FancyLogger +from mellea.stdlib.base import ( CBlock, Component, Context, + GenerateLog, ModelOutputThunk, + TemplateRepresentation, ) -from mellea.backends.aloras import Alora -from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import GenerateLog, ModelOutputThunk, TemplateRepresentation def default_output_to_bool(x: CBlock | str) -> bool: @@ -136,7 +136,7 @@ async def validate( # and its template gets populated with the output correctly. req_copy = copy(self) req_copy._output = last_output.value - llm_as_a_judge_result = backend.generate_from_context( + llm_as_a_judge_result, _ = backend.generate_from_context( req_copy, ctx, format=format, model_options=model_options ) await llm_as_a_judge_result.avalue() @@ -248,7 +248,7 @@ async def validate( # and its template gets populated with the output correctly. req_copy = copy(self) req_copy._output = last_output.value - llm_as_a_judge_result = backend.generate_from_context( + llm_as_a_judge_result, _ = backend.generate_from_context( req_copy, ctx, format=format, model_options=model_options ) await llm_as_a_judge_result.avalue() diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py index ec18fcdf..1dff0d6d 100644 --- a/mellea/stdlib/safety/guardian.py +++ b/mellea/stdlib/safety/guardian.py @@ -94,7 +94,7 @@ def _guardian_validate(self, ctx: Context): Code is adopted from https://huggingface.co/ibm-granite/granite-guardian-3.2-3b-a800m#quickstart-example Args: - ctx (Context): The context object containing the last turn of the conversation. + ctx (LegacyContext): The context object containing the last turn of the conversation. Returns: bool: True if there is no identified risk, False otherwise. 
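The hunks above, together with the new `mellea/stdlib/funcs.py`, switch callers to a convention in which `backend.generate_from_context` returns a `(ModelOutputThunk, Context)` pair and contexts are extended with `add()` rather than mutated in place. A minimal usage sketch of that convention follows; it assumes a locally running Ollama backend and the Granite model id referenced elsewhere in this patch, and it is illustrative only, not part of the patch:

```python
# Illustrative only (assumes an Ollama server is available and the Granite model
# id below resolves there). Shows the (result, new_context) calling convention
# introduced by mellea/stdlib/funcs.py: contexts are immutable, so each call
# returns the extended context for the caller to thread forward.
import mellea.stdlib.funcs as mfuncs
from mellea.backends.model_ids import IBM_GRANITE_3_3_8B
from mellea.backends.ollama import OllamaModelBackend
from mellea.stdlib.base import ChatContext

backend = OllamaModelBackend(IBM_GRANITE_3_3_8B)
ctx = ChatContext()  # add() returns a new ChatContext; `ctx` itself is never mutated

# instruct() returns the model output plus the context extended with this turn.
answer, ctx = mfuncs.instruct("Compute 1+1", context=ctx, backend=backend)

# Re-using the returned context carries the first turn into the next call.
reply, ctx = mfuncs.chat(
    "Express that result as a Greek letter.", context=ctx, backend=backend
)
print(str(answer), reply.content)
```

Returning a fresh context instead of mutating a shared one is also what lets the reworked sampling strategies below keep a separate `Context` per attempt (`sample_contexts`) and hand the chosen one back via `SamplingResult.result_ctx`.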
diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index ff7ab3a2..3dbf6912 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -1,22 +1,15 @@ """sampling methods go here.""" import abc -from collections.abc import Callable, Coroutine from copy import deepcopy -from typing import Any import tqdm -from mellea import LinearContext +import mellea.stdlib.funcs as mfuncs +from mellea.backends import Backend, BaseModelSubclass +from mellea.helpers.async_helpers import wait_for_all_mots from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import ( - CBlock, - Component, - Context, - ContextTurn, - GenerateLog, - ModelOutputThunk, -) +from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk from mellea.stdlib.chat import Message from mellea.stdlib.instruction import Instruction from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult @@ -28,12 +21,14 @@ class SamplingResult(CBlock): def __init__( self, result: ModelOutputThunk, + result_ctx: Context, success: bool, *, sample_generations: list[ModelOutputThunk] | None = None, sample_validations: list[list[tuple[Requirement, ValidationResult]]] | None = None, sample_actions: list[Component] | None = None, + sample_contexts: list[Context] | None = None, ): """Initialize a new instance of sampling results. @@ -45,10 +40,12 @@ def __init__( """ super().__init__(value=result.value) self.result = result + self.result_ctx = result_ctx self.success = success self.sample_generations = sample_generations self.sample_validations = sample_validations self.sample_actions = sample_actions + self.sample_contexts = sample_contexts class SamplingStrategy(abc.ABC): @@ -58,25 +55,18 @@ class SamplingStrategy(abc.ABC): It allows setting custom validation and generation functions through properties. """ - # the function signature here matches that of m.validate - validate: ( - Callable[ - [list[Requirement], Context, Any, Any], - Coroutine[Any, Any, list[ValidationResult]], - ] - | None - ) = None - - generate: Callable[[Component, Context], ModelOutputThunk] | None = None - @abc.abstractmethod async def sample( self, action: Component, context: Context, + backend: Backend, requirements: list[Requirement], *, validation_ctx: Context | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, ) -> SamplingResult: """This method is the abstract method for sampling a given instruction. @@ -96,16 +86,7 @@ class BaseSamplingStrategy(SamplingStrategy): loop_budget: int def __init__( - self, - *, - loop_budget: int = 1, - validate: Callable[ - [list[Requirement], Context, Any, Any], - Coroutine[Any, Any, list[ValidationResult]], - ] - | None = None, - generate: (Callable[[Component, Context], ModelOutputThunk] | None) = None, - requirements: list[Requirement] | None = None, + self, *, loop_budget: int = 1, requirements: list[Requirement] | None = None ): """Initialize a new instance of the class with default parameters. @@ -121,29 +102,29 @@ def __init__( assert loop_budget > 0, "Loop budget must be at least 1." 
self.loop_budget = loop_budget - self.validate = validate # it's ok to be None here - self.generate = generate # it's ok to be None here self.requirements = requirements @staticmethod @abc.abstractmethod def repair( - ctx: Context, + old_ctx: Context, + new_ctx: Context, past_actions: list[Component], past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> Component: + ) -> tuple[Component, Context]: """ Repair function that is being invoked if not all requirements are fulfilled. It should return a next action component. Args: - ctx: The context to be passed to the sampling strategy. + old_ctx: The context WITHOUT the last action + output. + new_ctx: The context including the last action + output. past_actions: List of actions that have been executed (without success). past_results: List of (unsuccessful) generation results for these actions. past_val: List of validation results for the results. Returns: - The next action component. + The next action component and context to be used for the next generation attempt. """ ... @@ -170,10 +151,14 @@ async def sample( self, action: Component, context: Context, + backend: Backend, requirements: list[Requirement], *, - show_progress: bool = True, validation_ctx: Context | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + show_progress: bool = True, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -190,19 +175,14 @@ async def sample( Raises: AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling. """ - assert self.validate is not None, "Validation must be provided." - assert self.generate is not None, "Generate must be provided." - - # just to be sure to not cause issues to the OG context - ctx = context.copy() validation_ctx = validation_ctx if validation_ctx is not None else context - assert validation_ctx is not None, "Validation context must be provided." flog = FancyLogger.get_logger() sampled_results: list[ModelOutputThunk] = [] sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress # flag to determine whether we should show the pbar. @@ -224,22 +204,32 @@ async def sample( else range(self.loop_budget) # type: ignore ) - new_action = deepcopy(action) + next_action = deepcopy(action) + next_context = context for _ in loop_budget_range_iterator: # type: ignore loop_count += 1 if not show_progress: flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result = self.generate(new_action, ctx) + result, result_ctx = backend.generate_from_context( + next_action, + ctx=next_context, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) await result.avalue() # validation pass - val_scores_co = self.validate( - reqs, - validation_ctx, - result, - input=None, # type: ignore + val_scores_co = mfuncs._validate( + reqs=reqs, + context=result_ctx, + backend=backend, + output=result, + format=format, + model_options=model_options, + # tool_calls=tool_calls # Don't support using tool calls in validation strategies. 
) val_scores = await val_scores_co @@ -249,7 +239,8 @@ async def sample( # collect all data sampled_results.append(result) sampled_scores.append(constraint_scores) - sampled_actions.append(new_action) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) # if all vals are true -- break and return success if all(bool(s[1]) for s in constraint_scores): @@ -259,11 +250,15 @@ async def sample( ) # Cannot be None after generation. result._generate_log.is_final_result = True + # SUCCESS !!!! return SamplingResult( - result, + result=result, + result_ctx=result_ctx, success=True, sample_generations=sampled_results, sample_validations=sampled_scores, + sample_contexts=sample_contexts, + sample_actions=sampled_actions, ) else: @@ -272,8 +267,12 @@ async def sample( flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") # If we did not pass all constraints, update the instruction and try again. - new_action = self.repair( - ctx, sampled_actions, sampled_results, sampled_scores + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, ) flog.info( @@ -294,11 +293,13 @@ async def sample( sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore return SamplingResult( - sampled_results[best_failed_index], + result=sampled_results[best_failed_index], + result_ctx=sample_contexts[best_failed_index], success=False, sample_generations=sampled_results, sample_validations=sampled_scores, sample_actions=sampled_actions, + sample_contexts=sample_contexts, ) @@ -316,13 +317,14 @@ def select_from_failure( @staticmethod def repair( - ctx: Context, + old_ctx: Context, + new_ctx: Context, past_actions: list[Component], past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> Component: + ) -> tuple[Component, Context]: # repeat the last action again. - return past_actions[-1] + return past_actions[-1], old_ctx class RepairTemplateStrategy(BaseSamplingStrategy): @@ -339,11 +341,12 @@ def select_from_failure( @staticmethod def repair( - ctx: Context, + old_ctx: Context, + new_ctx: Context, past_actions: list[Component], past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> Component: + ) -> tuple[Component, Context]: pa = past_actions[-1] if isinstance(pa, Instruction): last_failed_reqs: list[Requirement] = [ @@ -354,8 +357,8 @@ def repair( ) return pa.copy_and_repair( repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" - ) - return past_actions[-1] + ), old_ctx + return pa, old_ctx class MultiTurnStrategy(BaseSamplingStrategy): @@ -372,18 +375,16 @@ def select_from_failure( @staticmethod def repair( - ctx: Context, + old_ctx: Context, + new_ctx: Context, past_actions: list[Component], past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> Component: - assert isinstance(ctx, LinearContext), ( - " Need linear context to run agentic sampling." + ) -> tuple[Component, Context]: + assert isinstance(new_ctx, ChatContext), ( + " Need chat context to run agentic sampling." 
) - # add failed execution to chat history - ctx.insert_turn(ContextTurn(past_actions[-1], past_results[-1])) - last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]] last_failed_reqs_str = "* " + "\n* ".join( [str(r.description) for r in last_failed_reqs] @@ -395,7 +396,7 @@ def repair( content=f"The following requirements have not been met: \n{last_failed_reqs_str}\n Please try again to fulfill the requirements.", ) - return next_action + return next_action, new_ctx class BestofNSamplingStrategy(BaseSamplingStrategy): @@ -407,10 +408,14 @@ async def sample( self, action: Component, context: Context, + backend: Backend, requirements: list[Requirement], *, - show_progress: bool = True, validation_ctx: Context | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + show_progress: bool = True, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -427,11 +432,6 @@ async def sample( Raises: AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling. """ - assert self.validate is not None, "Validation must be provided." - assert self.generate is not None, "Generate must be provided." - - # just to be sure to not cause issues to the OG context - ctx = context.copy() validation_ctx = validation_ctx if validation_ctx is not None else context assert validation_ctx is not None, "Validation context must be provided." @@ -440,12 +440,12 @@ async def sample( sampled_results: list[ModelOutputThunk] = [] sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] sampled_actions: list[Component] = [] + sample_contexts: list[Context] = [] successful_sampled_results: list[ModelOutputThunk] = [] successful_sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] successful_sampled_actions: list[Component] = [] - - # sampled_val_scores: list[float] = [] + successful_sample_contexts: list[Context] = [] # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress # flag to determine whether we should show the pbar. 
@@ -471,29 +471,54 @@ async def sample( ) loop_count = 0 - loop_budget_range_iterator = ( + generate_loop_budget_iterator = ( + tqdm.tqdm(range(self.loop_budget)) # type: ignore + if show_progress + else range(self.loop_budget) # type: ignore + ) + validate_loop_budget_iterator = ( tqdm.tqdm(range(self.loop_budget)) # type: ignore if show_progress else range(self.loop_budget) # type: ignore ) - new_action = deepcopy(action) - for _ in loop_budget_range_iterator: # type: ignore + next_action = deepcopy(action) + next_context = context + flog.info("BestofNSampling Generating Loop:") + for _ in generate_loop_budget_iterator: # type: ignore loop_count += 1 if not show_progress: flog.info(f"Running loop {loop_count} of {self.loop_budget}") # run a generation pass - result = self.generate(new_action, ctx) - await result.avalue() - - # validation pass - # action has user turn - val_scores_co = self.validate( - reqs, - validation_ctx, - result, - input=action._description, # type: ignore + result, result_ctx = backend.generate_from_context( + next_action, + ctx=next_context, + format=format, + model_options=model_options, + tool_calls=tool_calls, + ) + sampled_results.append(result) + sampled_actions.append(next_action) + sample_contexts.append(result_ctx) + + await wait_for_all_mots(sampled_results) + + flog.info("BestofNSampling Validation Loop:") + for i in validate_loop_budget_iterator: + result_ctx = sample_contexts[i] + result = sampled_results[i] + next_action = sampled_actions[i] + + val_scores_co = mfuncs._validate( + reqs=reqs, + context=result_ctx, + backend=backend, + output=result, + format=format, + model_options=model_options, + input=next_action._description, # type: ignore + # tool_calls=tool_calls # Don't support using tool calls in validation strategies. ) val_scores = await val_scores_co @@ -501,9 +526,7 @@ async def sample( constraint_scores = list(zip(reqs, val_scores)) # collect all data - sampled_results.append(result) sampled_scores.append(constraint_scores) - sampled_actions.append(new_action) # check if requirements pass else repair and re-sample # if all vals are true, save it and continue to get next sample @@ -516,7 +539,8 @@ async def sample( successful_sampled_results.append(result) successful_sampled_scores.append(constraint_scores) - successful_sampled_actions.append(new_action) + successful_sampled_actions.append(next_action) + successful_sample_contexts.append(result_ctx) else: # log partial success and continue @@ -524,8 +548,12 @@ async def sample( flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") # If we did not pass all constraints, update the instruction and try again. 
- new_action = self.repair( - ctx, sampled_actions, sampled_results, sampled_scores + next_action, next_context = self.repair( + next_context, + result_ctx, + sampled_actions, + sampled_results, + sampled_scores, ) # find max reward amongst results for which all requirements have passed @@ -544,22 +572,26 @@ async def sample( assert scorer_preference_ordering is not None if scorer_preference_ordering == "max": - best_result, best_score = max( - zip(successful_sampled_results, scores), key=lambda x: x[1] + best_result, best_score, best_context = max( + zip(successful_sampled_results, scores, successful_sample_contexts), + key=lambda x: x[1], ) elif scorer_preference_ordering == "min": - best_result, best_score = min( - zip(successful_sampled_results, scores), key=lambda x: x[1] + best_result, best_score, best_context = min( + zip(successful_sampled_results, scores, successful_sample_contexts), + key=lambda x: x[1], ) else: raise NotImplementedError return SamplingResult( best_result, + result_ctx=best_context, success=True, sample_generations=sampled_results, sample_validations=sampled_scores, sample_actions=sampled_actions, + sample_contexts=sample_contexts, ) # if all failures, call select from failure @@ -577,10 +609,12 @@ async def sample( ) return SamplingResult( sampled_results[best_failed_index], + result_ctx=sample_contexts[best_failed_index], success=False, sample_generations=sampled_results, sample_validations=sampled_scores, sample_actions=sampled_actions, + sample_contexts=sample_contexts, ) @staticmethod @@ -605,11 +639,12 @@ def select_from_failure( @staticmethod def repair( - ctx: Context, + old_ctx: Context, + new_ctx: Context, past_actions: list[Component], past_results: list[ModelOutputThunk], past_val: list[list[tuple[Requirement, ValidationResult]]], - ) -> Component: + ) -> tuple[Component, Context]: pa = past_actions[-1] if isinstance(pa, Instruction): last_failed_reqs: list[Requirement] = [ @@ -620,5 +655,5 @@ def repair( ) return pa.copy_and_repair( repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" - ) - return past_actions[-1] + ), old_ctx + return past_actions[-1], old_ctx diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 47bf0581..ef55c9f8 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -2,22 +2,14 @@ from __future__ import annotations -import asyncio import contextvars -import threading -from collections.abc import Coroutine -from copy import deepcopy from typing import Any, Literal, overload from PIL import Image as PILImage +import mellea.stdlib.funcs as mfuncs from mellea.backends import Backend, BaseModelSubclass -from mellea.backends.formatter import FormatterBackend -from mellea.backends.model_ids import ( - IBM_GRANITE_3_2_8B, - IBM_GRANITE_3_3_8B, - ModelIdentifier, -) +from mellea.backends.model_ids import IBM_GRANITE_3_3_8B, ModelIdentifier from mellea.backends.ollama import OllamaModelBackend from mellea.backends.openai import OpenAIBackend from mellea.helpers.fancy_logger import FancyLogger @@ -25,18 +17,13 @@ CBlock, Component, Context, - ContextTurn, GenerateLog, ImageBlock, - LinearContext, ModelOutputThunk, SimpleContext, ) -from mellea.stdlib.chat import Message, ToolMessage -from mellea.stdlib.instruction import Instruction -from mellea.stdlib.mify import mify -from mellea.stdlib.mobject import MObjectProtocol -from mellea.stdlib.requirement import Requirement, ValidationResult, check, req +from mellea.stdlib.chat import Message +from mellea.stdlib.requirement import 
Requirement, ValidationResult from mellea.stdlib.sampling import SamplingResult, SamplingStrategy # Global context variable for the context session @@ -84,7 +71,7 @@ def backend_name_to_class(name: str) -> Any: def start_session( backend_name: Literal["ollama", "hf", "openai", "watsonx", "litellm"] = "ollama", model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B, - ctx: Context | None = SimpleContext(), + ctx: Context | None = None, *, model_options: dict | None = None, **backend_kwargs, @@ -103,10 +90,11 @@ def start_session( - "hf" or "huggingface": Use HuggingFace transformers backend - "openai": Use OpenAI API backend - "watsonx": Use IBM WatsonX backend + - "litellm": Use the LiteLLM backend model_id: Model identifier or name. Can be a `ModelIdentifier` from mellea.backends.model_ids or a string model name. ctx: Context manager for conversation history. Defaults to SimpleContext(). - Use LinearContext() for chat-style conversations. + Use ChatContext() for chat-style conversations. model_options: Additional model configuration options that will be passed to the backend (e.g., temperature, max_tokens, etc.). **backend_kwargs: Additional keyword arguments passed to the backend constructor. @@ -137,9 +125,9 @@ def start_session( with start_session("openai", "gpt-4", model_options={"temperature": 0.7}): response = chat("Write a poem") - # Using HuggingFace with LinearContext for conversations - from mellea.stdlib.base import LinearContext - with start_session("hf", "microsoft/DialoGPT-medium", ctx=LinearContext()): + # Using HuggingFace with ChatContext for conversations + from mellea.stdlib.base import ChatContext + with start_session("hf", "microsoft/DialoGPT-medium", ctx=ChatContext()): chat("Hello!") chat("How are you?") # Remembers previous message @@ -155,6 +143,9 @@ def start_session( ) assert backend_class is not None backend = backend_class(model_id, model_options=model_options, **backend_kwargs) + + if ctx is None: + ctx = SimpleContext() return MelleaSession(backend, ctx) @@ -172,6 +163,8 @@ class MelleaSession: Note: we put the `instruct`, `validate`, and other convenience functions here instead of in `Context` or `Backend` to avoid import resolution issues. """ + ctx: Context + def __init__(self, backend: Backend, ctx: Context | None = None): """Initializes a new Mellea session with the provided backend and context. @@ -181,15 +174,11 @@ def __init__(self, backend: Backend, ctx: Context | None = None): model_options (Optional[dict]): model options, which will upsert into the model/backend's defaults. """ self.backend = backend - self.ctx = ctx if ctx is not None else SimpleContext() + self.ctx: Context = ctx if ctx is not None else SimpleContext() self._backend_stack: list[tuple[Backend, dict | None]] = [] self._session_logger = FancyLogger.get_logger() self._context_token = None - # Necessary for async. `m.*` functions should always run in this event loop. - self._event_loop = asyncio.new_event_loop() - threading.Thread(target=self._event_loop.run_forever, daemon=True).start() - def __enter__(self): """Enter context manager and set this session as the current global session.""" self._context_token = _context_session.set(self) @@ -202,9 +191,6 @@ def __exit__(self, exc_type, exc_val, exc_tb): _context_session.reset(self._context_token) self._context_token = None - def __del__(self): - self._close_event_loop() - def _push_model_state(self, new_backend: Backend, new_model_opts: dict): """The backend and model options used within a `Context` can be temporarily changed. 
This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`. @@ -232,69 +218,15 @@ def _pop_model_state(self) -> bool: def reset(self): """Reset the context state.""" - self.ctx.reset() + self.ctx = self.ctx.reset_to_new() def cleanup(self) -> None: """Clean up session resources.""" - self._close_event_loop() self.reset() self._backend_stack.clear() if hasattr(self.backend, "close"): self.backend.close() # type: ignore - def _close_event_loop(self) -> None: - """Called when deleting the session. Cleans up the session's event loop.""" - if self._event_loop: - try: - tasks = asyncio.all_tasks(self._event_loop) - for task in tasks: - task.cancel() - - async def finalize_tasks(): - # TODO: We can log errors here if needed. - await asyncio.gather(*tasks, return_exceptions=True) - - out = asyncio.run_coroutine_threadsafe( - finalize_tasks(), self._event_loop - ) - - # Timeout if needed. - out.result(5) - except Exception: - pass - - # Finally stop the event loop for this session. - self._event_loop.stop() - - def summarize(self) -> ModelOutputThunk: - """Summarizes the current context.""" - raise NotImplementedError() - - @staticmethod - def _parse_and_clean_image_args( - images_: list[ImageBlock] | list[PILImage.Image] | None = None, - ) -> list[ImageBlock] | None: - images: list[ImageBlock] | None = None - if images_ is not None: - assert isinstance(images_, list), "Images should be a list or None." - - if len(images_) > 0: - if isinstance(images_[0], PILImage.Image): - images = [ - ImageBlock.from_pil_image(i) - for i in images_ - if isinstance(i, PILImage.Image) - ] - else: - images = images_ # type: ignore - assert isinstance(images, list) - assert all(isinstance(i, ImageBlock) for i in images), ( - "All images should be ImageBlocks now." - ) - else: - images = None - return images - @overload def act( self, @@ -319,17 +251,7 @@ def act( tool_calls: bool = False, ) -> SamplingResult: ... - def act( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = None, - return_sampling_results: bool = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: + def act(self, action: Component, **kwargs) -> ModelOutputThunk | SamplingResult: """Runs a generic action, and adds both the action and the result to the context. Args: @@ -345,120 +267,11 @@ def act( A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. """ - # Run everything in the specific event loop for this session. - out = asyncio.run_coroutine_threadsafe( - self._act( - action, - requirements=requirements, - strategy=strategy, - return_sampling_results=return_sampling_results, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ), - self._event_loop, + result, context = mfuncs.act( + action, context=self.ctx, backend=self.backend, **kwargs ) - - return out.result() - - async def _act( - self, - action: Component, - *, - requirements: list[Requirement] | None = None, - strategy: SamplingStrategy | None = None, - return_sampling_results: bool = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: - """Asynchronous version of .act; runs a generic action, and adds both the action and the result to the context. 
- - Args: - action: the Component from which to generate. - requirements: used as additional requirements when a sampling strategy is provided - strategy: a SamplingStrategy that describes the strategy for validating and repairing/retrying for the instruct-validate-repair pattern. None means that no particular sampling strategy is used. - return_sampling_results: attach the (successful and failed) sampling attempts to the results. - format: if set, the BaseModel to use for constrained decoding. - model_options: additional model options, which will upsert into the model/backend's defaults. - tool_calls: if true, tool calling is enabled. - - Returns: - A ModelOutputThunk if `return_sampling_results` is `False`, else returns a `SamplingResult`. - """ - sampling_result: SamplingResult | None = None - generate_logs: list[GenerateLog] = [] - - if return_sampling_results: - assert strategy is not None, ( - "Must provide a SamplingStrategy when return_sampling_results==True" - ) - - if strategy is None: - result = self.backend.generate_from_context( - action, - ctx=self.ctx, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - await result.avalue() - - # ._generate_log should never be None after generation. - assert result._generate_log is not None - result._generate_log.is_final_result = True - generate_logs.append(result._generate_log) - - else: - # Default validation strategy just validates all of the provided requirements. - if strategy.validate is None: - strategy.validate = ( - lambda reqs, val_ctx, output, input=None: self._validate( # type: ignore - reqs, output=output, input=input - ) - ) - - # Default generation strategy just generates from context. - if strategy.generate is None: - strategy.generate = ( - lambda sample_action, gen_ctx: self.backend.generate_from_context( - sample_action, - ctx=gen_ctx, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) - ) - - if requirements is None: - requirements = [] - - sampling_result = await strategy.sample( - action, self.ctx, requirements=requirements - ) - - assert sampling_result.sample_generations is not None - for result in sampling_result.sample_generations: - assert ( - result._generate_log is not None - ) # Cannot be None after generation. - generate_logs.append(result._generate_log) - - result = sampling_result.result - assert sampling_result.result._generate_log is not None - assert sampling_result.result._generate_log.is_final_result, ( - "generate logs from the final result returned by the sampling strategy must be marked as final" - ) - - self.ctx.insert_turn(ContextTurn(action, result), generate_logs=generate_logs) - - if return_sampling_results: - assert ( - sampling_result is not None - ) # Needed for the type checker but should never happen. - return sampling_result - else: - return result + self.ctx = context + return result @overload def instruct( @@ -498,23 +311,7 @@ def instruct( tool_calls: bool = False, ) -> SamplingResult: ... 
- def instruct( - self, - description: str, - *, - images: list[ImageBlock] | list[PILImage.Image] | None = None, - requirements: list[Requirement | str] | None = None, - icl_examples: list[str | CBlock] | None = None, - grounding_context: dict[str, str | CBlock | Component] | None = None, - user_variables: dict[str, str] | None = None, - prefix: str | CBlock | None = None, - output_prefix: str | CBlock | None = None, - strategy: SamplingStrategy | None = None, - return_sampling_results: bool = False, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - tool_calls: bool = False, - ) -> ModelOutputThunk | SamplingResult: + def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingResult: """Generates from an instruction. Args: @@ -533,33 +330,17 @@ def instruct( images: A list of images to be used in the instruction or None if none. """ - requirements = [] if requirements is None else requirements - icl_examples = [] if icl_examples is None else icl_examples - grounding_context = dict() if grounding_context is None else grounding_context - - images = self._parse_and_clean_image_args(images) - - # All instruction options are forwarded to create a new Instruction object. - i = Instruction( - description=description, - requirements=requirements, - icl_examples=icl_examples, - grounding_context=grounding_context, - user_variables=user_variables, - prefix=prefix, - output_prefix=output_prefix, - images=images, + r = mfuncs.instruct( + description, context=self.ctx, backend=self.backend, **kwargs ) - return self.act( - i, - requirements=i.requirements, - strategy=strategy, - return_sampling_results=return_sampling_results, - format=format, - model_options=model_options, - tool_calls=tool_calls, - ) # type: ignore[call-overload] + if isinstance(r, SamplingResult): + self.ctx = r.result_ctx + return r + else: + result, context = r + self.ctx = context + return result def chat( self, @@ -573,25 +354,21 @@ def chat( tool_calls: bool = False, ) -> Message: """Sends a simple chat message and returns the response. Adds both messages to the Context.""" - if user_variables is not None: - content_resolved = Instruction.apply_user_dict_from_jinja( - user_variables, content - ) - else: - content_resolved = content - images = self._parse_and_clean_image_args(images) - user_message = Message(role=role, content=content_resolved, images=images) - result = self.act( - user_message, + result, context = mfuncs.chat( + content=content, + context=self.ctx, + backend=self.backend, + role=role, + images=images, + user_variables=user_variables, format=format, model_options=model_options, tool_calls=tool_calls, ) - parsed_assistant_message = result.parsed_repr - assert isinstance(parsed_assistant_message, Message) - return parsed_assistant_message + self.ctx = context + return result def validate( self, @@ -604,85 +381,17 @@ def validate( input: CBlock | None = None, ) -> list[ValidationResult]: """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - # Run everything in the specific event loop for this session. - out = asyncio.run_coroutine_threadsafe( - self._validate( - reqs=reqs, - output=output, - format=format, - model_options=model_options, - generate_logs=generate_logs, - input=input, - ), - self._event_loop, - ) - # Wait for and return the result. 
- return out.result() - - async def _validate( - self, - reqs: Requirement | list[Requirement], - *, - output: CBlock | None = None, - format: type[BaseModelSubclass] | None = None, - model_options: dict | None = None, - generate_logs: list[GenerateLog] | None = None, - input: CBlock | None = None, - ) -> list[ValidationResult]: - """Asynchronous version of .validate; validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" - # Turn a solitary requirement in to a list of requirements, and then reqify if needed. - reqs = [reqs] if not isinstance(reqs, list) else reqs - reqs = [Requirement(req) if type(req) is str else req for req in reqs] - if output is None: - validation_target_ctx = self.ctx - else: - validation_target_ctx = SimpleContext() - - if input is not None: - # some validators may need input as well as output - validation_target_ctx.insert_turn( - ContextTurn( - input, - output, # type: ignore - ), # type: ignore - generate_logs=generate_logs, - ) - else: - validation_target_ctx.insert(output) - - rvs: list[ValidationResult] = [] - coroutines: list[Coroutine[Any, Any, ValidationResult]] = [] - - for requirement in reqs: - val_result_co = requirement.validate( - self.backend, - validation_target_ctx, - format=format, - model_options=model_options, - ) - coroutines.append(val_result_co) - - for val_result in await asyncio.gather(*coroutines): - rvs.append(val_result) - - # If the validator utilized a backend to generate a result, attach the corresponding - # info to the generate_logs list. - if generate_logs is not None: - if val_result.thunk is not None: - thunk = val_result.thunk - assert ( - thunk._generate_log is not None - ) # Cannot be None after generation. - generate_logs.append(thunk._generate_log) - else: - # We have to append None here so that the logs line-up. - # TODO: A better solution should be found for this edge case. - # This is the only scenario where ValidationResults are supposed to line - # up with GenerateLogs. - generate_logs.append(None) # type: ignore - - return rvs + return mfuncs.validate( + reqs=reqs, + context=self.ctx, + backend=self.backend, + output=output, + format=format, + model_options=model_options, + generate_logs=generate_logs, + input=input, + ) def query( self, @@ -705,16 +414,17 @@ def query( Returns: ModelOutputThunk: The result of the query as processed by the backend. """ - if not isinstance(obj, MObjectProtocol): - obj = mify(obj) - - assert isinstance(obj, MObjectProtocol) - q = obj.get_query_object(query) - - answer = self.act( - q, format=format, model_options=model_options, tool_calls=tool_calls + result, context = mfuncs.query( + obj=obj, + query=query, + context=self.ctx, + backend=self.backend, + format=format, + model_options=model_options, + tool_calls=tool_calls, ) - return answer + self.ctx = context + return result def transform( self, @@ -735,132 +445,35 @@ def transform( the return type will be always be ModelOutputThunk. If a tool was called, the return type will be the return type of the function called, usually the type of the object passed in. """ - if not isinstance(obj, MObjectProtocol): - obj = mify(obj) - - assert isinstance(obj, MObjectProtocol) - t = obj.get_transform_object(transformation) - - # Check that your model / backend supports tool calling. - # This might throw an error when tools are provided but can't be handled by one or the other. 
- transformed = self.act( - t, format=format, model_options=model_options, tool_calls=True + result, context = mfuncs.transform( + obj=obj, + transformation=transformation, + context=self.ctx, + backend=self.backend, + format=format, + model_options=model_options, ) - - tools = self._call_tools(transformed) - - # Transform only supports calling one tool call since it cannot currently synthesize multiple outputs. - # Attempt to choose the best one to call. - chosen_tool: ToolMessage | None = None - if len(tools) == 1: - # Only one function was called. Choose that one. - chosen_tool = tools[0] - - elif len(tools) > 1: - for output in tools: - if type(output._tool_output) is type(obj): - chosen_tool = output - break - - if chosen_tool is None: - chosen_tool = tools[0] - - FancyLogger.get_logger().warning( - f"multiple tool calls returned in transform of {obj} with description '{transformation}'; picked `{chosen_tool.name}`" - # type: ignore - ) - - if chosen_tool: - # Tell the user the function they should've called if no generated values were added. - if len(chosen_tool._tool.args.keys()) == 0: - FancyLogger.get_logger().warning( - f"the transform of {obj} with transformation description '{transformation}' resulted in a tool call with no generated arguments; consider calling the function `{chosen_tool._tool.name}` directly" - ) - - self.ctx.insert(chosen_tool) - FancyLogger.get_logger().info( - "added a tool message from transform to the context" - ) - return chosen_tool._tool_output - - return transformed - - def _call_tools(self, result: ModelOutputThunk) -> list[ToolMessage]: - """Call all the tools requested in a result's tool calls object. - - Returns: - list[ToolMessage]: A list of tool messages that can be empty. - """ - # There might be multiple tool calls returned. - outputs: list[ToolMessage] = [] - tool_calls = result.tool_calls - if tool_calls: - # Call the tools and decide what to do. - for name, tool in tool_calls.items(): - try: - output = tool.call_func() - except Exception as e: - output = e - - content = str(output) - if isinstance(self.backend, FormatterBackend): - content = self.backend.formatter.print(output) # type: ignore - - outputs.append( - ToolMessage( - role="tool", - content=content, - tool_output=output, - name=name, - args=tool.args, - tool=tool, - ) - ) - return outputs + self.ctx = context + return result # ############################### # Convenience functions # ############################### - def last_prompt(self) -> str | list[dict] | None: """Returns the last prompt that has been called from the session context. Returns: A string if the last prompt was a raw call to the model OR a list of messages (as role-msg-dicts). Is None if none could be found. 
""" - _, log = self.ctx.last_output_and_logs() - prompt = None + op = self.ctx.last_output() + if op is None: + return None + log = op._generate_log if isinstance(log, GenerateLog): - prompt = log.prompt + return log.prompt elif isinstance(log, list): last_el = log[-1] if isinstance(last_el, GenerateLog): - prompt = last_el.prompt - return prompt - - -# Convenience functions that use the current session -def instruct(description: str, **kwargs) -> ModelOutputThunk | SamplingResult: - """Instruct using the current session.""" - return get_session().instruct(description, **kwargs) - - -def chat(content: str, **kwargs) -> Message: - """Chat using the current session.""" - return get_session().chat(content, **kwargs) - - -def validate(reqs, **kwargs): - """Validate using the current session.""" - return get_session().validate(reqs, **kwargs) - - -def query(obj: Any, query_str: str, **kwargs) -> ModelOutputThunk: - """Query using the current session.""" - return get_session().query(obj, query_str, **kwargs) - - -def transform(obj: Any, transformation: str, **kwargs): - """Transform using the current session.""" - return get_session().transform(obj, transformation, **kwargs) + return last_el.prompt + return None diff --git a/test/backends/test_huggingface.py b/test/backends/test_huggingface.py index 9f85e1bb..4e3ab441 100644 --- a/test/backends/test_huggingface.py +++ b/test/backends/test_huggingface.py @@ -9,7 +9,7 @@ from mellea.backends.formatter import TemplateFormatter from mellea.backends.huggingface import LocalHFBackend from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, LinearContext, SimpleContext +from mellea.stdlib.base import CBlock, ChatContext, SimpleContext from mellea.stdlib.requirement import ( ALoraRequirement, LLMaJRequirement, @@ -34,7 +34,7 @@ def backend(): @pytest.fixture(scope="function") def session(backend): """Fresh HuggingFace session for each test.""" - session = MelleaSession(backend, ctx=LinearContext()) + session = MelleaSession(backend, ctx=ChatContext()) yield session session.reset() @@ -145,7 +145,8 @@ def test_instruct(session): def test_multiturn(session): session.instruct("Compute 1+1") beta = session.instruct( - "Take the result of the previous sum and find the corresponding letter in the greek alphabet." 
+ "Take the result of the previous sum and find the corresponding letter in the greek alphabet.", + model_options={ModelOption.MAX_NEW_TOKENS: 300}, ) assert "β" in str(beta).lower() words = session.instruct("Now list five English words that start with that letter.") @@ -228,8 +229,8 @@ class Answer(pydantic.BaseModel): def test_async_parallel_requests(session): async def parallel_requests(): model_opts = {ModelOption.STREAM: True} - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) m1_val = None m2_val = None @@ -237,7 +238,7 @@ async def parallel_requests(): m1_val = await mot1.astream() if not mot2.is_computed(): m2_val = await mot2.astream() - + assert m1_val is not None, "should be a string val after generation" assert m2_val is not None, "should be a string val after generation" @@ -256,7 +257,7 @@ async def parallel_requests(): @pytest.mark.qualitative def test_async_avalue(session): async def avalue(): - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) m1_final_val = await mot1.avalue() assert m1_final_val is not None assert m1_final_val == mot1.value diff --git a/test/backends/test_litellm_ollama.py b/test/backends/test_litellm_ollama.py index 79442f72..7999e4cf 100644 --- a/test/backends/test_litellm_ollama.py +++ b/test/backends/test_litellm_ollama.py @@ -23,7 +23,7 @@ def test_litellm_ollama_chat(session): assert res is not None assert isinstance(res, Message) assert "2" in res.content, ( - f"Expected a message with content containing 2 but found {output_message}" + f"Expected a message with content containing 2 but found {res}" ) @pytest.mark.qualitative @@ -57,7 +57,7 @@ def test_litellm_ollama_instruct_options(session): assert isinstance(res.value, str) # make sure that homer_simpson is in the logged model_options - assert "homer_simpson" in session.ctx.last_output_and_logs()[1].model_options + assert "homer_simpson" in res._generate_log.model_options # make sure the backend function filters out the model option when passing to the generate call backend = session.backend @@ -81,8 +81,8 @@ def is_happy(text: str) -> bool: def test_async_parallel_requests(session): async def parallel_requests(): model_opts = {ModelOption.STREAM: True} - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) m1_val = None m2_val = None @@ -90,7 +90,7 @@ async def parallel_requests(): m1_val = await mot1.astream() if not mot2.is_computed(): m2_val = await mot2.astream() - + assert m1_val is not None, "should be a string val after generation" assert m2_val is not None, "should be a string val after generation" @@ -109,7 +109,7 @@ async def parallel_requests(): 
@pytest.mark.qualitative def test_async_avalue(session): async def avalue(): - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) m1_final_val = await mot1.avalue() assert m1_final_val is not None assert m1_final_val == mot1.value diff --git a/test/backends/test_ollama.py b/test/backends/test_ollama.py index 8e9b8631..806747f4 100644 --- a/test/backends/test_ollama.py +++ b/test/backends/test_ollama.py @@ -5,10 +5,10 @@ import pytest from typing_extensions import Annotated -from mellea import SimpleContext, start_session +from mellea import start_session from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock -from mellea.stdlib.requirement import Requirement +from mellea.stdlib.base import CBlock, SimpleContext +from mellea.stdlib.requirement import Requirement, simple_validate @pytest.fixture(scope="function") @@ -37,7 +37,7 @@ def test_instruct_with_requirement(session): email_word_count_req = Requirement( f"The email should be at most 100", - validation_fn=lambda x: len(" ".split(x.last_output().value)) <= 100, + validation_fn=simple_validate(lambda x: len(" ".split(x)) <= 100), ) happy_tone_req = Requirement( @@ -52,6 +52,7 @@ def test_instruct_with_requirement(session): ) print(results) + @pytest.mark.qualitative def test_chat(session): output_message = session.chat("What is 1+1?") @@ -59,6 +60,7 @@ def test_chat(session): f"Expected a message with content containing 2 but found {output_message}" ) + @pytest.mark.qualitative def test_format(session): class Person(pydantic.BaseModel): @@ -78,7 +80,7 @@ class Email(pydantic.BaseModel): output = session.instruct( "Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. 
", format=Email, - model_options={ModelOption.MAX_NEW_TOKENS: 2**8}, + model_options={ModelOption.MAX_NEW_TOKENS: 2 ** 8}, ) print("Formatted output:") email = Email.model_validate_json( @@ -92,6 +94,7 @@ class Email(pydantic.BaseModel): # assert email.to.email_address.endswith("example.com") pass + @pytest.mark.qualitative def test_generate_from_raw(session): prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"] @@ -127,11 +130,14 @@ class Answer(pydantic.BaseModel): f"formatting directive failed for {random_result.value}: {e.json()}" ) + def test_async_parallel_requests(session): async def parallel_requests(): model_opts = {ModelOption.STREAM: True} - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), + model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), + model_options=model_opts) m1_val = None m2_val = None @@ -139,7 +145,7 @@ async def parallel_requests(): m1_val = await mot1.astream() if not mot2.is_computed(): m2_val = await mot2.astream() - + assert m1_val is not None, "should be a string val after generation" assert m2_val is not None, "should be a string val after generation" @@ -153,15 +159,19 @@ async def parallel_requests(): assert m1_final_val == mot1.value assert m2_final_val == mot2.value + asyncio.run(parallel_requests()) + def test_async_avalue(session): async def avalue(): - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) m1_final_val = await mot1.avalue() assert m1_final_val is not None assert m1_final_val == mot1.value + asyncio.run(avalue()) + if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/backends/test_openai_ollama.py b/test/backends/test_openai_ollama.py index d773e645..77487c6c 100644 --- a/test/backends/test_openai_ollama.py +++ b/test/backends/test_openai_ollama.py @@ -11,7 +11,7 @@ from mellea.backends.model_ids import META_LLAMA_3_2_1B from mellea.backends.openai import OpenAIBackend from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk, SimpleContext +from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext, SimpleContext @pytest.fixture(scope="module") @@ -36,7 +36,7 @@ def backend(gh_run: int): @pytest.fixture(scope="function") def m_session(backend): """Fresh OpenAI session for each test.""" - session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True)) + session = MelleaSession(backend, ctx=ChatContext()) yield session session.reset() @@ -147,11 +147,11 @@ class Email(pydantic.BaseModel): def test_async_parallel_requests(m_session): async def parallel_requests(): model_opts = {ModelOption.STREAM: True} - mot1 = m_session.backend.generate_from_context( + mot1, _ = m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext(), model_options=model_opts ) - mot2 = m_session.backend.generate_from_context( - CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts + mot2, _ = m_session.backend.generate_from_context( + CBlock("Say Goodbye!"),SimpleContext(), model_options=model_opts ) m1_val = None @@ -184,7 +184,7 @@ async def parallel_requests(): def 
test_async_avalue(m_session): async def avalue(): - mot1 = m_session.backend.generate_from_context( + mot1, _ = m_session.backend.generate_from_context( CBlock("Say Hello."), SimpleContext() ) m1_final_val = await mot1.avalue() diff --git a/test/backends/test_openai_vllm/test_openai_vllm.py b/test/backends/test_openai_vllm/test_openai_vllm.py index c83d9abe..7101bea7 100644 --- a/test/backends/test_openai_vllm/test_openai_vllm.py +++ b/test/backends/test_openai_vllm/test_openai_vllm.py @@ -1,6 +1,6 @@ # test/rits_backend_tests/test_openai_integration.py from mellea import MelleaSession -from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk +from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext from mellea.backends.openai import OpenAIBackend from mellea.backends.aloras.openai.granite_aloras import add_granite_aloras from mellea.stdlib.requirement import Requirement, ALoraRequirement, LLMaJRequirement @@ -34,7 +34,7 @@ class TestOpenAIBackend: base_url="http://0.0.0.0:8000/v1", api_key="EMPTY", ) - m = MelleaSession(backend, ctx=LinearContext(is_chat_context=True)) + m = MelleaSession(backend, ctx=ChatContext()) def test_instruct(self): self.m.reset() @@ -136,7 +136,7 @@ class TestOpenAIALoraStuff: base_url="http://localhost:8000/v1", api_key="EMPTY", ) - m = MelleaSession(backend, ctx=LinearContext()) + m = MelleaSession(backend, ctx=ChatContext()) add_granite_aloras(backend) def test_system_prompt(self): diff --git a/test/backends/test_watsonx.py b/test/backends/test_watsonx.py index 0a9b917a..12ec10d3 100644 --- a/test/backends/test_watsonx.py +++ b/test/backends/test_watsonx.py @@ -10,7 +10,7 @@ from mellea.backends.formatter import TemplateFormatter from mellea.backends.types import ModelOption from mellea.backends.watsonx import WatsonxAIBackend -from mellea.stdlib.base import CBlock, LinearContext, ModelOutputThunk, SimpleContext +from mellea.stdlib.base import CBlock, ModelOutputThunk, ChatContext, SimpleContext @pytest.fixture(scope="module") @@ -31,7 +31,7 @@ def session(backend: WatsonxAIBackend): pytest.skip("Skipping watsonx tests.") else: """Fresh Watson session for each test.""" - session = MelleaSession(backend, ctx=LinearContext(is_chat_context=True)) + session = MelleaSession(backend, ctx=ChatContext()) yield session session.reset() @@ -42,7 +42,6 @@ def test_instruct(session: MelleaSession): assert isinstance(result, ModelOutputThunk) assert "2" in result.value # type: ignore -@pytest.mark.xfail(reason="watsonx python sdk has weird interactions with event loops; causes some errors with pytest.") @pytest.mark.qualitative def test_multiturn(session: MelleaSession): session.instruct("What is the capital of France?") @@ -56,7 +55,6 @@ def test_chat(session): f"Expected a message with content containing 2 but found {output_message}" ) -@pytest.mark.xfail(reason="watsonx python sdk has weird interactions with event loops; causes some errors with pytest.") @pytest.mark.qualitative def test_format(session: MelleaSession): class Person(pydantic.BaseModel): @@ -105,8 +103,8 @@ def test_generate_from_raw(session: MelleaSession): def test_async_parallel_requests(session): async def parallel_requests(): model_opts = {ModelOption.STREAM: True} - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), model_options=model_opts) - mot2 = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext(), 
model_options=model_opts) + mot2, _ = session.backend.generate_from_context(CBlock("Say Goodbye!"), SimpleContext(), model_options=model_opts) m1_val = None m2_val = None @@ -114,7 +112,7 @@ async def parallel_requests(): m1_val = await mot1.astream() if not mot2.is_computed(): m2_val = await mot2.astream() - + assert m1_val is not None, "should be a string val after generation" assert m2_val is not None, "should be a string val after generation" @@ -130,13 +128,10 @@ async def parallel_requests(): assert m2_final_val == mot2.value asyncio.run(parallel_requests()) -# TODO: If this becomes a big issue, we will just have to re-instantiate the ModelInference object between requests. -# Ideally, we would only do this when creating a new m.session from the same backend. -@pytest.mark.xfail(reason="watsonx python sdk apparently doesn't support running across multiple async event loops.") @pytest.mark.qualitative def test_async_avalue(session): async def avalue(): - mot1 = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) + mot1, _ = session.backend.generate_from_context(CBlock("Say Hello."), SimpleContext()) m1_final_val = await mot1.avalue() assert m1_final_val is not None assert m1_final_val == mot1.value diff --git a/test/stdlib_basics/test_base.py b/test/stdlib_basics/test_base.py index 3353742c..e19c6adc 100644 --- a/test/stdlib_basics/test_base.py +++ b/test/stdlib_basics/test_base.py @@ -1,5 +1,5 @@ import pytest -from mellea.stdlib.base import CBlock, Component, LinearContext +from mellea.stdlib.base import CBlock, Component def test_cblock(): @@ -27,27 +27,5 @@ def format_for_llm(self) -> str: assert len(c.parts()) == 0 -def test_context(): - ctx = LinearContext(window_size=3) - ctx.insert(CBlock("a")) - ctx.insert(CBlock("b")) - ctx.insert(CBlock("c")) - ctx.insert(CBlock("d")) - - -def test_actions_for_available_tools(): - ctx = LinearContext(window_size=3) - ctx.insert(CBlock("a")) - ctx.insert(CBlock("b")) - for_generation = ctx.render_for_generation() - assert for_generation is not None - - actions = ctx.actions_for_available_tools() - assert actions is not None - - assert len(for_generation) == len(actions) - for i in range(len(actions)): - assert actions[i] == for_generation[i] - if __name__ == "__main__": pytest.main([__file__]) diff --git a/test/stdlib_basics/test_base_context.py b/test/stdlib_basics/test_base_context.py new file mode 100644 index 00000000..17a9502b --- /dev/null +++ b/test/stdlib_basics/test_base_context.py @@ -0,0 +1,69 @@ +import pytest + +from mellea.stdlib.base import Context, CBlock, SimpleContext, ChatContext + + +def context_construction(cls: type[Context]): + tree0 = cls() + tree1 = tree0.add(CBlock("abc")) + assert tree1.previous_node == tree0 + + tree1a = tree0.add(CBlock("def")) + assert tree1a.previous_node == tree0 + + +def test_context_construction(): + context_construction(SimpleContext) + context_construction(ChatContext) + + +def large_context_construction(cls: type[Context]): + root = cls() + + full_graph: Context = root + for i in range(1000): + full_graph = full_graph.add(CBlock(f"abc{i}")) + + all_data = full_graph.as_list() + assert len(all_data) == 1000 + + +def test_large_context_construction(): + large_context_construction(SimpleContext) + large_context_construction(ChatContext) + + +def test_render_view_for_simple_context(): + ctx = SimpleContext() + for i in range(5): + ctx = ctx.add(CBlock(f"a {i}")) + assert len(ctx.as_list()) == 5, "Adding 5 items to context should result in 5 items" + assert 
len(ctx.view_for_generation()) == 0, "Render size should be 0 -- NO HISTORY for SimpleContext" + + +def test_render_view_for_chat_context(): + ctx = ChatContext(window_size=3) + for i in range(5): + ctx = ctx.add(CBlock(f"a {i}")) + assert len(ctx.as_list()) == 5, "Adding 5 items to context should result in 5 items" + assert len(ctx.view_for_generation()) == 3, "Render size should be 3" + + +def test_actions_for_available_tools(): + ctx = ChatContext(window_size=3) + ctx = ctx.add(CBlock("a")) + ctx = ctx.add(CBlock("b")) + + for_generation = ctx.view_for_generation() + assert for_generation is not None + + actions = ctx.actions_for_available_tools() + assert actions is not None + + assert len(for_generation) == len(actions) + for i in range(len(actions)): + assert actions[i] == for_generation[i] + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/test/stdlib_basics/test_chat_view.py b/test/stdlib_basics/test_chat_view.py index c56b7c2e..81caa6f8 100644 --- a/test/stdlib_basics/test_chat_view.py +++ b/test/stdlib_basics/test_chat_view.py @@ -1,7 +1,7 @@ import pytest -from mellea.stdlib.base import LinearContext, ModelOutputThunk +from mellea.stdlib.base import ChatContext, ModelOutputThunk from mellea.stdlib.chat import Message, as_chat_history from mellea.stdlib.session import start_session @@ -9,7 +9,7 @@ @pytest.fixture(scope="function") def linear_session(): """Session with linear context for chat tests.""" - session = start_session(ctx=LinearContext()) + session = start_session(ctx=ChatContext()) yield session session.reset() @@ -27,13 +27,16 @@ def test_chat_view_linear_ctx(linear_session): linear_session.chat("What is 2+2?") assert len(as_chat_history(linear_session.ctx)) == 4 assert all([type(x) == Message for x in as_chat_history(linear_session.ctx)]) + assert len(linear_session.ctx.view_for_generation()) == 4 -@pytest.mark.skip("linearize() returns [] for a SimpleContext... that's going to be annoying.") +# @pytest.mark.skip("linearize() returns [] for a SimpleContext... that's going to be annoying.") def test_chat_view_simple_ctx(simple_session): simple_session.chat("What is 1+1?") simple_session.chat("What is 2+2?") - assert len(as_chat_history(simple_session.ctx)) == 2 + assert len(as_chat_history(simple_session.ctx)) == 4 assert all([type(x) == Message for x in as_chat_history(simple_session.ctx)]) + assert len(simple_session.ctx.view_for_generation()) == 0 + if __name__ == "__main__": diff --git a/test/stdlib_basics/test_contextual_session.py b/test/stdlib_basics/test_contextual_session.py index d7a39e8c..a401f117 100644 --- a/test/stdlib_basics/test_contextual_session.py +++ b/test/stdlib_basics/test_contextual_session.py @@ -1,226 +1,227 @@ -from typing import Literal - -import pytest - -from mellea import chat, generative, instruct, query, start_session, transform, validate -from mellea.backends.model_ids import IBM_GRANITE_3_3_8B, META_LLAMA_3_2_1B -from mellea.stdlib.base import ModelOutputThunk -from mellea.stdlib.mify import MifiedProtocol, mify -from mellea.stdlib.requirement import req -from mellea.stdlib.session import MelleaSession, get_session - - -@pytest.fixture(scope="module") -def model_id(gh_run: int): - if gh_run == 1: - return META_LLAMA_3_2_1B - else: - return IBM_GRANITE_3_3_8B - - -@generative -def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... - - -@generative -def generate_summary(text: str) -> str: ... 
- - -@mify(fields_include={"name", "age"}) -class TestPerson: - def __init__(self, name: str, age: int): - self.name = name - self.age = age - - def get_info(self) -> str: - """Get person information.""" - return f"{self.name} is {self.age} years old" - - -def test_basic_contextual_session(model_id): - """Test basic contextual session usage with convenience functions.""" - with start_session(model_id=model_id): - # Test instruct - result = instruct("Say hello") - assert isinstance(result, ModelOutputThunk) - assert result.value is not None - - # Test that we can get the session - current_session = get_session() - assert isinstance(current_session, MelleaSession) - - -def test_no_active_session_error(): - """Test error handling when no session is active.""" - with pytest.raises(RuntimeError, match="No active session found"): - get_session() - - with pytest.raises(RuntimeError, match="No active session found"): - instruct("test") - - with pytest.raises(RuntimeError, match="No active session found"): - chat("test") - -@pytest.mark.qualitative -def test_generative_with_contextual_session(model_id): - """Test generative slots work with contextual sessions.""" - with start_session(model_id=model_id): - # Test without explicit session parameter - result = classify_sentiment(text="I love this!") - assert result in ["positive", "negative"] - - # Test with summary generation - summary = generate_summary(text="A short text about something interesting.") - assert isinstance(summary, str) - assert len(summary) > 0 - -@pytest.mark.qualitative -def test_generative_backward_compatibility(model_id): - """Test that generative slots still work with explicit session parameter.""" - with start_session(model_id=model_id) as m: - # Test old pattern still works - result = classify_sentiment(m, text="I love this!") - assert result in ["positive", "negative"] - - -def test_mify_with_contextual_session(model_id): - """Test mify functionality with contextual sessions.""" - with start_session(model_id=model_id): - person = TestPerson("Alice", 30) - assert isinstance(person, MifiedProtocol) - - # Test query functionality - query_result = query(person, "What is this person's name?") - assert isinstance(query_result, ModelOutputThunk) - - # Test transform functionality - transform_result = transform(person, "Make this person 5 years older") - # Transform can return either ModelOutputThunk or the tool output when tools are called - assert transform_result is not None - - -def test_nested_sessions(model_id): - """Test nested sessions behavior.""" - with start_session(model_id=model_id) as outer_session: - outer_result = instruct("outer session test") - assert isinstance(outer_result, ModelOutputThunk) - - with start_session(model_id=model_id) as inner_session: - # Inner session should be active - current_session = get_session() - assert current_session is inner_session - - inner_result = instruct("inner session test") - assert isinstance(inner_result, ModelOutputThunk) - - # After inner session exits, outer should be active again - current_session = get_session() - assert current_session is outer_session - - -def test_session_cleanup(model_id): - """Test session cleanup after context exit.""" - session_ref = None - with start_session(model_id=model_id) as m: - session_ref = m - instruct("test during session") - - # After exiting context, no session should be active - with pytest.raises(RuntimeError, match="No active session found"): - get_session() - - # Session should have been cleaned up - assert hasattr(session_ref, 
"ctx") - - -def test_all_convenience_functions(model_id): - """Test all convenience functions work within contextual session.""" - with start_session(model_id=model_id): - # Test instruct - instruct_result = instruct("Generate a greeting") - assert isinstance(instruct_result, ModelOutputThunk) - - # Test chat - chat_result = chat("Hello there") - assert hasattr(chat_result, "content") - - # Test validate - validation = validate([req("The response should be positive")]) - assert validation is not None - - # Test query with a mified object - test_person = TestPerson("Test", 42) - query_result = query(test_person, "What is the name?") - assert isinstance(query_result, ModelOutputThunk) - - # Test transform with a mified object - transform_result = transform(test_person, "Double the age") - assert transform_result is not None - - -def test_session_with_parameters(model_id): - """Test contextual session with custom parameters.""" - with start_session(backend_name="ollama", model_id=model_id) as m: - result = instruct("test with parameters") - assert isinstance(result, ModelOutputThunk) - assert isinstance(m, MelleaSession) - - -def test_multiple_sequential_sessions(model_id): - """Test multiple sequential contextual sessions.""" - # First session - with start_session(model_id=model_id): - result1 = instruct("first session") - assert isinstance(result1, ModelOutputThunk) - - # Ensure no session is active between contexts - with pytest.raises(RuntimeError, match="No active session found"): - get_session() - - # Second session - with start_session(model_id=model_id): - result2 = instruct("second session") - assert isinstance(result2, ModelOutputThunk) - - -def test_contextual_session_with_mified_object_methods(model_id): - """Test that mified objects work properly within contextual sessions.""" - with start_session(model_id=model_id): - person = TestPerson("Bob", 25) - - # Test that mified object methods work - query_obj = person.get_query_object("What's the age?") - assert query_obj is not None - - transform_obj = person.get_transform_object("Make older") - assert transform_obj is not None - - # Test format_for_llm - llm_format = person.format_for_llm() - assert llm_format is not None - assert hasattr(llm_format, "args") - - -def test_session_methods_with_mified_objects(model_id): - """Test using session query/transform methods with mified objects.""" - with start_session(model_id=model_id) as m: - person = TestPerson("Charlie", 35) - - # Test session query method - query_result = m.query(person, "What is this person's age?") - assert isinstance(query_result, ModelOutputThunk) - - # Test session transform method - transform_result = m.transform(person, "Make this person younger") - # Transform can return either ModelOutputThunk or tool output when tools are called - assert transform_result is not None - - # Verify mified objects have query/transform object creation methods - assert hasattr(person, "get_query_object") - assert hasattr(person, "get_transform_object") - assert hasattr(person, "_query_type") - assert hasattr(person, "_transform_type") - - -if __name__ == "__main__": - pytest.main([__file__]) +# TODO: needs to be rewritten +# from typing import Literal +# +# import pytest +# +# from mellea import chat, generative, instruct, query, start_session, transform, validate +# from mellea.backends.model_ids import IBM_GRANITE_3_3_8B, META_LLAMA_3_2_1B +# from mellea.stdlib.base import ModelOutputThunk +# from mellea.stdlib.mify import MifiedProtocol, mify +# from mellea.stdlib.requirement 
import req +# from mellea.stdlib.session import MelleaSession, get_session +# +# +# @pytest.fixture(scope="module") +# def model_id(gh_run: int): +# if gh_run == 1: +# return META_LLAMA_3_2_1B +# else: +# return IBM_GRANITE_3_3_8B +# +# +# @generative +# def classify_sentiment(text: str) -> Literal["positive", "negative"]: ... +# +# +# @generative +# def generate_summary(text: str) -> str: ... +# +# +# @mify(fields_include={"name", "age"}) +# class TestPerson: +# def __init__(self, name: str, age: int): +# self.name = name +# self.age = age +# +# def get_info(self) -> str: +# """Get person information.""" +# return f"{self.name} is {self.age} years old" +# +# +# def test_basic_contextual_session(model_id): +# """Test basic contextual session usage with convenience functions.""" +# with start_session(model_id=model_id): +# # Test instruct +# result = instruct("Say hello") +# assert isinstance(result, ModelOutputThunk) +# assert result.value is not None +# +# # Test that we can get the session +# current_session = get_session() +# assert isinstance(current_session, MelleaSession) +# +# +# def test_no_active_session_error(): +# """Test error handling when no session is active.""" +# with pytest.raises(RuntimeError, match="No active session found"): +# get_session() +# +# with pytest.raises(RuntimeError, match="No active session found"): +# instruct("test") +# +# with pytest.raises(RuntimeError, match="No active session found"): +# chat("test") +# +# @pytest.mark.qualitative +# def test_generative_with_contextual_session(model_id): +# """Test generative slots work with contextual sessions.""" +# with start_session(model_id=model_id): +# # Test without explicit session parameter +# result = classify_sentiment(text="I love this!") +# assert result in ["positive", "negative"] +# +# # Test with summary generation +# summary = generate_summary(text="A short text about something interesting.") +# assert isinstance(summary, str) +# assert len(summary) > 0 +# +# @pytest.mark.qualitative +# def test_generative_backward_compatibility(model_id): +# """Test that generative slots still work with explicit session parameter.""" +# with start_session(model_id=model_id) as m: +# # Test old pattern still works +# result = classify_sentiment(m, text="I love this!") +# assert result in ["positive", "negative"] +# +# +# def test_mify_with_contextual_session(model_id): +# """Test mify functionality with contextual sessions.""" +# with start_session(model_id=model_id): +# person = TestPerson("Alice", 30) +# assert isinstance(person, MifiedProtocol) +# +# # Test query functionality +# query_result = query(person, "What is this person's name?") +# assert isinstance(query_result, ModelOutputThunk) +# +# # Test transform functionality +# transform_result = transform(person, "Make this person 5 years older") +# # Transform can return either ModelOutputThunk or the tool output when tools are called +# assert transform_result is not None +# +# +# def test_nested_sessions(model_id): +# """Test nested sessions behavior.""" +# with start_session(model_id=model_id) as outer_session: +# outer_result = instruct("outer session test") +# assert isinstance(outer_result, ModelOutputThunk) +# +# with start_session(model_id=model_id) as inner_session: +# # Inner session should be active +# current_session = get_session() +# assert current_session is inner_session +# +# inner_result = instruct("inner session test") +# assert isinstance(inner_result, ModelOutputThunk) +# +# # After inner session exits, outer should be active again +# 
current_session = get_session() +# assert current_session is outer_session +# +# +# def test_session_cleanup(model_id): +# """Test session cleanup after context exit.""" +# session_ref = None +# with start_session(model_id=model_id) as m: +# session_ref = m +# instruct("test during session") +# +# # After exiting context, no session should be active +# with pytest.raises(RuntimeError, match="No active session found"): +# get_session() +# +# # Session should have been cleaned up +# assert hasattr(session_ref, "ctx") +# +# +# def test_all_convenience_functions(model_id): +# """Test all convenience functions work within contextual session.""" +# with start_session(model_id=model_id): +# # Test instruct +# instruct_result = instruct("Generate a greeting") +# assert isinstance(instruct_result, ModelOutputThunk) +# +# # Test chat +# chat_result = chat("Hello there") +# assert hasattr(chat_result, "content") +# +# # Test validate +# validation = validate([req("The response should be positive")]) +# assert validation is not None +# +# # Test query with a mified object +# test_person = TestPerson("Test", 42) +# query_result = query(test_person, "What is the name?") +# assert isinstance(query_result, ModelOutputThunk) +# +# # Test transform with a mified object +# transform_result = transform(test_person, "Double the age") +# assert transform_result is not None +# +# +# def test_session_with_parameters(model_id): +# """Test contextual session with custom parameters.""" +# with start_session(backend_name="ollama", model_id=model_id) as m: +# result = instruct("test with parameters") +# assert isinstance(result, ModelOutputThunk) +# assert isinstance(m, MelleaSession) +# +# +# def test_multiple_sequential_sessions(model_id): +# """Test multiple sequential contextual sessions.""" +# # First session +# with start_session(model_id=model_id): +# result1 = instruct("first session") +# assert isinstance(result1, ModelOutputThunk) +# +# # Ensure no session is active between contexts +# with pytest.raises(RuntimeError, match="No active session found"): +# get_session() +# +# # Second session +# with start_session(model_id=model_id): +# result2 = instruct("second session") +# assert isinstance(result2, ModelOutputThunk) +# +# +# def test_contextual_session_with_mified_object_methods(model_id): +# """Test that mified objects work properly within contextual sessions.""" +# with start_session(model_id=model_id): +# person = TestPerson("Bob", 25) +# +# # Test that mified object methods work +# query_obj = person.get_query_object("What's the age?") +# assert query_obj is not None +# +# transform_obj = person.get_transform_object("Make older") +# assert transform_obj is not None +# +# # Test format_for_llm +# llm_format = person.format_for_llm() +# assert llm_format is not None +# assert hasattr(llm_format, "args") +# +# +# def test_session_methods_with_mified_objects(model_id): +# """Test using session query/transform methods with mified objects.""" +# with start_session(model_id=model_id) as m: +# person = TestPerson("Charlie", 35) +# +# # Test session query method +# query_result = m.query(person, "What is this person's age?") +# assert isinstance(query_result, ModelOutputThunk) +# +# # Test session transform method +# transform_result = m.transform(person, "Make this person younger") +# # Transform can return either ModelOutputThunk or tool output when tools are called +# assert transform_result is not None +# +# # Verify mified objects have query/transform object creation methods +# assert hasattr(person, 
"get_query_object") +# assert hasattr(person, "get_transform_object") +# assert hasattr(person, "_query_type") +# assert hasattr(person, "_transform_type") +# +# +# if __name__ == "__main__": +# pytest.main([__file__]) diff --git a/test/stdlib_basics/test_funcs.py b/test/stdlib_basics/test_funcs.py new file mode 100644 index 00000000..f652eb98 --- /dev/null +++ b/test/stdlib_basics/test_funcs.py @@ -0,0 +1,37 @@ + + +import pytest + +from mellea.backends.types import ModelOption +from mellea.stdlib.base import CBlock +from mellea.stdlib.funcs import instruct +from mellea.stdlib.session import start_session + + +@pytest.fixture(scope="module") +def m_session(gh_run): + if gh_run == 1: + m = start_session( + "ollama", + model_id="llama3.2:1b", + model_options={ModelOption.MAX_NEW_TOKENS: 5}, + ) + else: + m = start_session( + "ollama", + model_id="granite3.3:8b", + model_options={ModelOption.MAX_NEW_TOKENS: 5}, + ) + yield m + del m + +def test_func_context(m_session): + initial_ctx = m_session.ctx + backend = m_session.backend + + out, ctx = instruct("Write a sentence.", initial_ctx, backend) + assert initial_ctx is not ctx + assert ctx._data is out + +if __name__ == "__main__": + pytest.main([__file__]) \ No newline at end of file diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index d33c14e4..ebcace55 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -1,7 +1,6 @@ import pytest from typing import Literal from mellea import generative, start_session -from mellea.stdlib.base import LinearContext @generative diff --git a/test/stdlib_basics/test_reqlib_markdown.py b/test/stdlib_basics/test_reqlib_markdown.py index 3d7ef903..82445b9a 100644 --- a/test/stdlib_basics/test_reqlib_markdown.py +++ b/test/stdlib_basics/test_reqlib_markdown.py @@ -1,13 +1,13 @@ import pytest -from mellea.stdlib.base import CBlock, ModelOutputThunk, LinearContext, Context +from mellea.stdlib.base import CBlock, ModelOutputThunk, Context, ChatContext from mellea.stdlib.reqlib.md import is_markdown_list, is_markdown_table, as_markdown_list from mellea.stdlib.requirement import default_output_to_bool def from_model(s: str) -> Context: - ctx = LinearContext() - ctx.insert(ModelOutputThunk(value=s, meta={"test": True})) + ctx = ChatContext() + ctx = ctx.add(ModelOutputThunk(value=s, meta={"test": True})) return ctx diff --git a/test/stdlib_basics/test_requirement.py b/test/stdlib_basics/test_requirement.py index 5d11b00a..a1bef684 100644 --- a/test/stdlib_basics/test_requirement.py +++ b/test/stdlib_basics/test_requirement.py @@ -1,11 +1,11 @@ import asyncio import pytest -from mellea.stdlib.base import LinearContext, ModelOutputThunk +from mellea.stdlib.base import ChatContext, ModelOutputThunk from mellea.stdlib.requirement import Requirement, simple_validate from mellea.stdlib.session import start_session -ctx = LinearContext() -ctx.insert(ModelOutputThunk("test")) +ctx = ChatContext() +ctx = ctx.add(ModelOutputThunk("test")) def test_llmaj_validation_req_output_field(): m = start_session(ctx=ctx) @@ -39,4 +39,4 @@ def test_simple_validate_invalid(): val_result = validation_func(ctx) if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__]) diff --git a/test/stdlib_basics/test_sampling_ctx.py b/test/stdlib_basics/test_sampling_ctx.py index 2e0ba033..0e78e7ea 100644 --- a/test/stdlib_basics/test_sampling_ctx.py +++ b/test/stdlib_basics/test_sampling_ctx.py @@ -1,6 +1,7 @@ import pytest -from 
mellea import LinearContext, start_session +from mellea import start_session from mellea.backends import ModelOption +from mellea.stdlib.base import ChatContext from mellea.stdlib.sampling import ( MultiTurnStrategy, RejectionSamplingStrategy, @@ -10,7 +11,7 @@ class TestSamplingCtxCase: m = start_session( - model_options={ModelOption.MAX_NEW_TOKENS: 100}, ctx=LinearContext() + model_options={ModelOption.MAX_NEW_TOKENS: 100}, ctx=ChatContext() ) def _run_asserts_for_ctx_testing(self, res): @@ -27,12 +28,9 @@ def _run_asserts_for_ctx_testing(self, res): assert len(res.sample_validations[0]) == 3, ( "there should be 3 validation results." ) - assert len(self.m.ctx._ctx) == 2, ( - "there should only be a message and a response in the ctx." - ) def test_ctx_for_rejection_sampling(self): - self.m.ctx.reset() + self.m.reset() res = self.m.instruct( "Write a sentence.", requirements=[ @@ -44,10 +42,13 @@ def test_ctx_for_rejection_sampling(self): return_sampling_results=True, ) self._run_asserts_for_ctx_testing(res) + assert len(self.m.ctx.as_list()) == 2, ( + "there should only be a message and a response in the ctx." + ) assert len(self.m.last_prompt()) == 1, "Last prompt should only have only one instruction inside - independent of sampling iterations." def test_ctx_for_multiturn(self): - self.m.ctx.reset() + self.m.reset() res = self.m.instruct( "Write a sentence.", requirements=[ @@ -60,7 +61,9 @@ def test_ctx_for_multiturn(self): ) self._run_asserts_for_ctx_testing(res) - + assert len(self.m.ctx.as_list()) >= 2, ( + "there should be at least a message and a response in the ctx; more if the first result failed validation" + ) assert len(self.m.last_prompt()) == len(res.sample_generations)*2-1, "For n sampling iterations there should be 2n-1 prompt conversation elements in the last prompt." if __name__ == "__main__": diff --git a/test/stdlib_basics/test_session.py b/test/stdlib_basics/test_session.py index 9caa8d6f..67168a38 100644 --- a/test/stdlib_basics/test_session.py +++ b/test/stdlib_basics/test_session.py @@ -31,9 +31,11 @@ def test_start_session_openai_with_kwargs(gh_run): base_url=f"http://{os.environ.get('OLLAMA_HOST', 'localhost:11434')}/v1", api_key="ollama", ) + initial_ctx = m.ctx response = m.instruct("testing") assert isinstance(response, ModelOutputThunk) assert response.value is not None + assert initial_ctx is not m.ctx if __name__ == "__main__": diff --git a/test/stdlib_basics/test_session_ctx.py b/test/stdlib_basics/test_session_ctx.py deleted file mode 100644 index 17f8acd6..00000000 --- a/test/stdlib_basics/test_session_ctx.py +++ /dev/null @@ -1,67 +0,0 @@ -from mellea.stdlib.base import ( - CBlock, - GenerateLog, - ModelOutputThunk, - ContextTurn, - LinearContext, - SimpleContext, - Context, -) - - -def run_on_context_1(ctx: Context): - ctx.insert(CBlock("abc"), generate_logs=[GenerateLog()]) - o, l = ctx.last_output_and_logs() - assert ( - o is None - ), "There is only a Cblock in the context, not an output (ModelOutputThunk). 
Shouldn't return anything" - assert l is None, "If there is no output, there should be no corresponding log" - - -def run_on_context_2(ctx: Context): - ctx.insert(ModelOutputThunk("def"), generate_logs=[GenerateLog(), GenerateLog()]) - o, l = ctx.last_output_and_logs(all_intermediate_results=True) - assert o is not None - assert isinstance(l, list) - assert len(l) == 2 - assert isinstance(l[0], GenerateLog) - - -def run_on_context_3(ctx: Context): - for is_final in (True, False): - ctx.insert_turn( - ContextTurn(None, ModelOutputThunk("def")), - generate_logs=[GenerateLog(is_final_result=is_final)], - ) - o, l = ctx.last_output_and_logs() - print(f"o={o}, l={l}") - assert o is not None - assert isinstance(l, GenerateLog) - - -def test_ctx_single_log(): - ctx = SimpleContext() - run_on_context_1(ctx) - run_on_context_2(ctx) - run_on_context_3(ctx) - - -def test_ctx_multi_log(): - ctx = LinearContext() - run_on_context_1(ctx) - run_on_context_2(ctx) - run_on_context_3(ctx) - - -def test_ctx_overlap(): - ctx = SimpleContext() - run_on_context_1(ctx) - ctx = LinearContext() - run_on_context_1(ctx) - - ctx2 = SimpleContext() - last_logs = ctx.get_logs_by_index(-1) - assert isinstance(last_logs, list) - assert len(last_logs) == 1 - assert isinstance(last_logs[0], GenerateLog) - run_on_context_1(ctx2) diff --git a/test/stdlib_basics/test_vision_ollama.py b/test/stdlib_basics/test_vision_ollama.py index 6b2936e5..d0c0ed1c 100644 --- a/test/stdlib_basics/test_vision_ollama.py +++ b/test/stdlib_basics/test_vision_ollama.py @@ -69,8 +69,9 @@ def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.I assert "yes" in instr.value.lower() or "no" in instr.value.lower() # make sure you get the last action - _, log = m_session.ctx.last_output_and_logs() - last_action = log.action + turn = m_session.ctx.last_turn() + assert turn is not None + last_action = turn.model_input assert isinstance(last_action, Instruction) assert len(last_action._images) > 0 @@ -79,7 +80,7 @@ def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.I assert image0 == image_block # get prompt message - lp = log.prompt + lp = turn.output._generate_log.prompt assert isinstance(lp, list) assert len(lp) == 1 @@ -111,8 +112,9 @@ def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, g assert "yes" in ct.content.lower() or "no" in ct.content.lower() # make sure you get the last action - _, log = m_session.ctx.last_output_and_logs() - last_action = log.action + turn = m_session.ctx.last_turn() + assert turn is not None + last_action = turn.model_input assert isinstance(last_action, Message) assert len(last_action.images) > 0 @@ -121,7 +123,7 @@ def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, g assert image0_str == ImageBlock.from_pil_image(pil_image)._value # get prompt message - lp = log.prompt + lp = turn.output._generate_log.prompt assert isinstance(lp, list) assert len(lp) == 1 diff --git a/test/stdlib_basics/test_vision_openai.py b/test/stdlib_basics/test_vision_openai.py index da49fa76..385a8ffe 100644 --- a/test/stdlib_basics/test_vision_openai.py +++ b/test/stdlib_basics/test_vision_openai.py @@ -73,8 +73,9 @@ def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.I assert "yes" in instr.value.lower() or "no" in instr.value.lower() # make sure you get the last action - _, log = m_session.ctx.last_output_and_logs() - last_action = log.action + turn = m_session.ctx.last_turn() + assert turn is not None + 
last_action = turn.model_input assert isinstance(last_action, Instruction) assert len(last_action._images) > 0 @@ -83,7 +84,7 @@ def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.I assert image0 == image_block # get prompt message - lp = log.prompt + lp = turn.output._generate_log.prompt assert isinstance(lp, list) assert len(lp) == 1 @@ -123,8 +124,9 @@ def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, g assert "yes" in ct.content.lower() or "no" in ct.content.lower() # make sure you get the last action - _, log = m_session.ctx.last_output_and_logs() - last_action = log.action + turn = m_session.ctx.last_turn() + assert turn is not None + last_action = turn.model_input assert isinstance(last_action, Message) assert len(last_action.images) > 0 @@ -133,7 +135,7 @@ def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, g assert image0_str == ImageBlock.from_pil_image(pil_image)._value # get prompt message - lp = log.prompt + lp = turn.output._generate_log.prompt assert isinstance(lp, list) assert len(lp) == 1 diff --git a/test/test_formatter_baseclasses.py b/test/test_formatter_baseclasses.py index 1c91ce55..2532846f 100644 --- a/test/test_formatter_baseclasses.py +++ b/test/test_formatter_baseclasses.py @@ -9,15 +9,9 @@ from mellea.backends.formatter import TemplateFormatter from mellea.backends.model_ids import ModelIdentifier, IBM_GRANITE_3_2_8B from mellea.stdlib.base import ( - BasicContext, CBlock, Component, - Context, - ContextTurn, - GenerateLog, - LinearContext, ModelOutputThunk, - SimpleContext, TemplateRepresentation, ) from mellea.stdlib.chat import Message @@ -117,52 +111,6 @@ def __init__(self, msg: Message) -> None: ), "parse should set the result object to result.parsed_repr if it's not parsing a message" -def test_print_context(tf: TemplateFormatter): - ctx = LinearContext() - with pytest.raises(AssertionError): - tf.print_context(ctx) - - ctx = LinearContext(is_chat_context=False) - ctx._ctx = [CBlock("1"), CBlock("2")] - output = tf.print_context(ctx) - assert type(output) == str - assert output == "12" - - with pytest.raises( - Exception, match="Do not know how to handle a SimpleContext yet." 
- ): - st_ctx = SimpleContext() - st_ctx.is_chat_context = False - tf.print_context(st_ctx) - - class _TestContext(BasicContext): - def reset(self): - pass - - def insert(self, value, *, key=None, generate_logs: list[GenerateLog] | None = None): - pass - - def insert_turn(self, turn, *, generate_logs: list[GenerateLog] | None = None,): - pass - - def copy(self) -> Context: - return self - - def _hash_for_kv_cache(self): - pass - - def render_for_generation(self) -> Optional[List[Component | CBlock]]: - pass - - def last_output(self) -> ModelOutputThunk | None: - pass - - def last_turn(self) -> ContextTurn | None: - pass - - with pytest.raises(Exception): - tf.print_context(_TestContext()) - def test_custom_template_string(tf: TemplateFormatter): class _TemplInstruction(Instruction): diff --git a/test/test_tool_calls.py b/test/test_tool_calls.py index f2643cc6..11c47a3c 100644 --- a/test/test_tool_calls.py +++ b/test/test_tool_calls.py @@ -4,16 +4,16 @@ from mellea.backends.ollama import OllamaModelBackend from mellea.backends.tools import add_tools_from_context_actions, add_tools_from_model_options from mellea.backends.types import ModelOption -from mellea.stdlib.base import CBlock, Component, ModelOutputThunk, TemplateRepresentation +from mellea.stdlib.base import CBlock, Component, ModelOutputThunk, TemplateRepresentation, ChatContext from mellea.stdlib.docs.richdocument import Table -from mellea.stdlib.session import LinearContext, MelleaSession +from mellea.stdlib.session import MelleaSession @pytest.fixture(scope="module") def m() -> MelleaSession: return MelleaSession( backend=OllamaModelBackend(), - ctx=LinearContext(), + ctx=ChatContext(), ) @@ -32,10 +32,10 @@ def table() -> Table: def test_tool_called_from_context_action(m: MelleaSession, table: Table): """Make sure tools can be called from actions in the context.""" - m.ctx.reset() + m.reset() # Insert a component with tools into the context. - m.ctx.insert(table) + m.ctx = m.ctx.add(table) # Create fake tools. def test1(): ... @@ -57,7 +57,7 @@ def test2(): ... def test_tool_called(m: MelleaSession, table: Table): """We don't force tools to be called. As a result, this test might unexpectedly fail.""" r = 10 - m.ctx.reset() + m.reset() returned_tool = False for i in range(r): @@ -72,7 +72,7 @@ def test_tool_called(m: MelleaSession, table: Table): def test_tool_not_called(m: MelleaSession, table: Table): """Ensure tools aren't always called when provided.""" r = 10 - m.ctx.reset() + m.reset() returned_no_tool = False for i in range(r):