diff --git a/docs/examples/instruct_validate_repair/sampling_with_prefix_strategy.py b/docs/examples/instruct_validate_repair/sampling_with_prefix_strategy.py new file mode 100644 index 00000000..dd56a9c7 --- /dev/null +++ b/docs/examples/instruct_validate_repair/sampling_with_prefix_strategy.py @@ -0,0 +1,30 @@ +from mellea.backends import ModelOption +from mellea.stdlib.requirement import check, req, simple_validate +from mellea.stdlib.sampling.prefix_cached import RejectionSamplingStrategyWithPrefix + +requirements = [ + req("The email should have a salutation"), # == r1 + req( + "Use only lower-case letters", + validation_fn=simple_validate(lambda x: x.lower() == x), + ), # == r2 + check("Do not mention purple elephants."), # == r3 + req("The email should be funny."), +] + +import mellea # noqa: E402 + + +def write_email(m: mellea.MelleaSession, name: str, notes: str) -> str: + email_candidate = m.instruct( + "Write an email to {{name}} using the notes following: {{notes}}.", + requirements=requirements, + strategy=RejectionSamplingStrategyWithPrefix(loop_budget=5), + user_variables={"name": name, "notes": notes}, + return_sampling_results=True, + ) + + if email_candidate.success: + return str(email_candidate.result) + else: + return email_candidate.sample_generations[0].value diff --git a/docs/examples/mcp/README.md b/docs/examples/mcp/README.md index 8202f789..48161c33 100644 --- a/docs/examples/mcp/README.md +++ b/docs/examples/mcp/README.md @@ -14,7 +14,7 @@ uv pip install "mcp[cli]" and run the example in MCP debug UI: ```bash -uv run mcp dev docs/examples/tutorial/mcp_example.py +uv run mcp dev docs/examples/mcp/mcp_example.py ``` diff --git a/mellea/stdlib/funcs.py b/mellea/stdlib/funcs.py index 5990edb2..b0a6cbb1 100644 --- a/mellea/stdlib/funcs.py +++ b/mellea/stdlib/funcs.py @@ -263,9 +263,9 @@ def chat( def validate( reqs: Requirement | list[Requirement], - context: Context, backend: Backend, *, + context: Context | None = None, output: CBlock | None = None, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -701,9 +701,9 @@ async def achat( async def avalidate( reqs: Requirement | list[Requirement], - context: Context, backend: Backend, *, + context: Context | None = None, output: CBlock | None = None, format: type[BaseModelSubclass] | None = None, model_options: dict | None = None, @@ -715,7 +715,12 @@ async def avalidate( reqs = [reqs] if not isinstance(reqs, list) else reqs reqs = [Requirement(req) if type(req) is str else req for req in reqs] + assert (context is not None) != (output is not None), ( + "Either context or output must be provided. Not both." + ) + if output is None: + assert context is not None validation_target_ctx = context else: validation_target_ctx = SimpleContext() diff --git a/mellea/stdlib/reqlib/md.py b/mellea/stdlib/reqlib/md.py index 3cee2770..d5d2539f 100644 --- a/mellea/stdlib/reqlib/md.py +++ b/mellea/stdlib/reqlib/md.py @@ -3,7 +3,7 @@ import mistletoe from mellea.stdlib.base import Context -from mellea.stdlib.requirement import Requirement +from mellea.stdlib.requirement import Requirement, ValidationResult, simple_validate # region lists @@ -25,8 +25,9 @@ def as_markdown_list(ctx: Context) -> list[str] | None: return None -def _md_list(ctx: Context): - return as_markdown_list(ctx) is not None +async def _md_list(ctx: Context) -> ValidationResult: + re = as_markdown_list(ctx) is not None + return ValidationResult(re) is_markdown_list = Requirement( @@ -40,11 +41,10 @@ def _md_list(ctx: Context): # region tables -def _md_table(ctx: Context): - raw_output = ctx.last_output() +def _md_table(raw_output: str) -> bool: assert raw_output is not None try: - parsed = mistletoe.Document(raw_output.value) + parsed = mistletoe.Document(raw_output) if len(parsed.children) != 1: return False return type(parsed.children[0]) is mistletoe.block_token.Table @@ -54,8 +54,7 @@ def _md_table(ctx: Context): is_markdown_table = Requirement( description="The output should be formatted as a Markdown table.", - validation_fn=_md_table, + validation_fn=simple_validate(_md_table), ) - # endregion diff --git a/mellea/stdlib/requirement.py b/mellea/stdlib/requirement.py index f10a3aaf..1679b675 100644 --- a/mellea/stdlib/requirement.py +++ b/mellea/stdlib/requirement.py @@ -2,7 +2,7 @@ import inspect import re -from collections.abc import Callable +from collections.abc import Awaitable, Callable from copy import copy from typing import Any, overload @@ -93,7 +93,7 @@ class Requirement(Component): def __init__( self, description: str | None = None, - validation_fn: Callable[[Context], ValidationResult] | None = None, + validation_fn: Callable[[Context], Awaitable[ValidationResult]] | None = None, *, output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool, check_only: bool = False, @@ -127,7 +127,7 @@ async def validate( """Chooses the appropriate validation strategy and applies that strategy.""" if self.validation_fn is not None: # Python validation strategy - return self.validation_fn(ctx) + return await self.validation_fn(ctx) else: # LLMaJ validation strategy. This includes ALora because the backend generate call will appropriately dispatch. assert self.output_to_bool is not None @@ -197,7 +197,7 @@ class ScorerRequirement(Requirement): def __init__( self, description: str | None = None, - validation_fn: Callable[[Context], ValidationResult] | None = None, + validation_fn: Callable[[Context], Awaitable[ValidationResult]] | None = None, preference_ordering: str = "max", *, output_to_bool: Callable[[CBlock | str], bool] | None = default_output_to_bool, @@ -234,7 +234,7 @@ async def validate( """Chooses the appropriate validation strategy and applies that strategy. Asserts that the returned ValidationResult has a valid score.""" if self.validation_fn is not None: # Python validation strategy - validation_result = self.validation_fn(ctx) + validation_result = await self.validation_fn(ctx) assert validation_result._score is not None, ( "ScorerRequirement must have a score that is not None" ) @@ -292,18 +292,18 @@ def check(*args, **kwargs) -> Requirement: @overload def simple_validate( fn: Callable[[str], tuple[bool, str]], -) -> Callable[[Context], ValidationResult]: ... +) -> Callable[[Context], Awaitable[ValidationResult]]: ... @overload def simple_validate( fn: Callable[[str], bool], *, reason: str | None = None -) -> Callable[[Context], ValidationResult]: ... +) -> Callable[[Context], Awaitable[ValidationResult]]: ... def simple_validate( fn: Callable[[str], Any], *, reason: str | None = None -) -> Callable[[Context], ValidationResult]: +) -> Callable[[Context], Awaitable[ValidationResult]]: """Syntactic sugar for writing validation functions that only operate over the last output from the model (interpreted as a string). This is useful when your validation logic only depends upon the most recent model output. For example: @@ -321,7 +321,7 @@ def simple_validate( reason: only used if the provided function returns a bool; if the validation function fails, a static reason for that failure to give to the llm when repairing """ - def validate(ctx: Context) -> ValidationResult: + async def validate(ctx: Context) -> ValidationResult: o = ctx.last_output() if o is None or o.value is None: FancyLogger.get_logger().warn( diff --git a/mellea/stdlib/sampling/base.py b/mellea/stdlib/sampling/base.py index 6520a7da..e0d51925 100644 --- a/mellea/stdlib/sampling/base.py +++ b/mellea/stdlib/sampling/base.py @@ -165,7 +165,6 @@ async def sample( reqs=reqs, context=result_ctx, backend=backend, - output=result, format=format, model_options=model_options, # tool_calls=tool_calls # Don't support using tool calls in validation strategies. diff --git a/mellea/stdlib/sampling/best_of_n.py b/mellea/stdlib/sampling/best_of_n.py index 59c402b2..5dbb7e1d 100644 --- a/mellea/stdlib/sampling/best_of_n.py +++ b/mellea/stdlib/sampling/best_of_n.py @@ -131,7 +131,6 @@ async def sample( reqs=reqs, context=result_ctx, backend=backend, - output=result, format=format, model_options=model_options, input=next_action._description, # type: ignore diff --git a/mellea/stdlib/sampling/prefix_cached.py b/mellea/stdlib/sampling/prefix_cached.py new file mode 100644 index 00000000..3317c1b5 --- /dev/null +++ b/mellea/stdlib/sampling/prefix_cached.py @@ -0,0 +1,85 @@ +"""Sampling Strategy that uses prefix caching idea based on two turn chats.""" + +from collections.abc import Awaitable, Callable + +from mellea.backends import Backend, BaseModelSubclass, ModelOption +from mellea.stdlib.base import ChatContext, Component, Context, ContextTurn +from mellea.stdlib.chat import Message +from mellea.stdlib.requirement import Requirement, ValidationResult +from mellea.stdlib.sampling import RejectionSamplingStrategy, SamplingResult + + +class RejectionSamplingStrategyWithPrefix(RejectionSamplingStrategy): + """Rejection Sampling class that uses the last turn as prefix cache for requirement checking.""" + + async def sample( + self, + action: Component, + context: Context, + backend: Backend, + requirements: list[Requirement] | None, + *, + validation_ctx: Context | None = None, + format: type[BaseModelSubclass] | None = None, + model_options: dict | None = None, + tool_calls: bool = False, + show_progress: bool = True, + ) -> SamplingResult: + """Sample method inherited from RejectionSamplingStrategy.""" + reqs: list[Requirement] = [] + if self.requirements is not None: + reqs += self.requirements + elif requirements is not None: + reqs += requirements + reqs = list(set(reqs)) + + def make_val( + req_string: str, + ) -> Callable[[Context], Awaitable[ValidationResult]]: + async def validate_agentic(ctx: Context) -> ValidationResult: + lt = ctx.last_turn() + assert isinstance(lt, ContextTurn) + assert lt.model_input is not None + assert lt.output is not None + + chat_ctx = ChatContext() + chat_ctx = chat_ctx.add(lt.model_input) + chat_ctx = chat_ctx.add(lt.output) + + action = Message( + role="user", + content=f"Does the output fulfill the requirement? Answer only with yes or no. Requirement: '{req_string}'", + ) + + llm_as_a_judge_result, _ = backend.generate_from_context( + action, + chat_ctx, + format=format, + model_options={ModelOption.MAX_NEW_TOKENS: 10}, + ) + value = await llm_as_a_judge_result.avalue() + + return ValidationResult( + result=value.lower().startswith("yes"), + reason=value, + thunk=llm_as_a_judge_result, + ) + + return validate_agentic + + for req in reqs: + if req.validation_fn is None: + req.validation_fn = make_val(str(req.description)) + + res = await super().sample( + action=action, + context=context, + backend=backend, + requirements=reqs, + validation_ctx=validation_ctx, + format=format, + model_options=model_options, + tool_calls=tool_calls, + show_progress=show_progress, + ) + return res diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 2a63a71a..665d4f2b 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -446,8 +446,8 @@ def validate( """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" return mfuncs.validate( reqs=reqs, - context=self.ctx, backend=self.backend, + context=self.ctx if output is None else None, output=output, format=format, model_options=model_options, @@ -730,7 +730,7 @@ async def avalidate( """Validates a set of requirements over the output (if provided) or the current context (if the output is not provided).""" return await mfuncs.avalidate( reqs=reqs, - context=self.ctx, + context=self.ctx if output is None else None, backend=self.backend, output=output, format=format, diff --git a/test/stdlib_basics/test_funcs.py b/test/stdlib_basics/test_funcs.py index 4e99afc5..61763ed2 100644 --- a/test/stdlib_basics/test_funcs.py +++ b/test/stdlib_basics/test_funcs.py @@ -62,12 +62,10 @@ async def test_ainstruct(m_session): assert ctx._data is out async def test_avalidate(m_session): - initial_ctx = m_session.ctx backend = m_session.backend val_result = await avalidate( reqs=[req("Be formal."), req("Avoid telling jokes.")], - context=initial_ctx, backend=backend, output=ModelOutputThunk("Here is an output.") ) @@ -77,4 +75,4 @@ async def test_avalidate(m_session): if __name__ == "__main__": - pytest.main([__file__]) \ No newline at end of file + pytest.main([__file__])