From e4b7cd4cea7cbdbfd47f05da4e250cea614cf4ed Mon Sep 17 00:00:00 2001 From: Hendrik Strobelt Date: Thu, 14 Aug 2025 16:58:53 +0200 Subject: [PATCH 1/5] preparing for new sampling: adding a repair field to Instruction that can only be set by `.copy_and_repair(...)`. The templates are updated to accommodate the new field. --- mellea/stdlib/instruction.py | 10 ++++++++++ mellea/templates/prompts/default/Instruction.jinja2 | 8 +++++++- mellea/templates/prompts/granite/Instruction.jinja2 | 8 +++++++- 3 files changed, 24 insertions(+), 2 deletions(-) diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/instruction.py index cca98930..e7b5f885 100644 --- a/mellea/stdlib/instruction.py +++ b/mellea/stdlib/instruction.py @@ -1,5 +1,7 @@ """Instructions.""" +from __future__ import annotations + from copy import deepcopy import jinja2 @@ -106,6 +108,7 @@ def __init__( self._output_prefix = ( blockify(output_prefix) if output_prefix is not None else None ) + self._repair_string: str | None = None def parts(self): """Returns all of the constituent parts of an Instruction.""" @@ -132,6 +135,7 @@ def format_for_llm(self) -> TemplateRepresentation: "output_prefix": ( self._output_prefix if self._output_prefix is not None else None ), + "repair": self._repair_string, }, tools=None, template_order=["*", "Instruction"], @@ -147,3 +151,9 @@ def apply_user_dict_from_jinja(user_dict: dict[str, str], s: str) -> str: def requirements(self) -> list[Requirement]: """Returns a list of Requirement instances.""" return self._requirements + + def copy_and_repair(self, repair_string: str) -> Instruction: + """Creates a copy of the instruction and adds/overwrites the repair string.""" + res = deepcopy(self) + res._repair_string = repair_string + return res diff --git a/mellea/templates/prompts/default/Instruction.jinja2 b/mellea/templates/prompts/default/Instruction.jinja2 index c576b806..52d0766e 100644 --- a/mellea/templates/prompts/default/Instruction.jinja2 +++ 
b/mellea/templates/prompts/default/Instruction.jinja2 @@ -37,8 +37,14 @@ Here is some grounding context: {%- endif -%} {% endblock grounding_context %} +{%- block repair_block -%} +{% if repair %} +{{ repair -}} +{%- endif -%} +{% endblock repair_block %} + {%- block output_prefix -%} {% if output_prefix %} {{ output_prefix -}} {%- endif -%} -{% endblock output_prefix %} \ No newline at end of file +{% endblock output_prefix %} diff --git a/mellea/templates/prompts/granite/Instruction.jinja2 b/mellea/templates/prompts/granite/Instruction.jinja2 index ea9851e9..a2c79d6a 100644 --- a/mellea/templates/prompts/granite/Instruction.jinja2 +++ b/mellea/templates/prompts/granite/Instruction.jinja2 @@ -37,8 +37,14 @@ Here are some examples of what the response might look like: {%- endif -%} {% endblock icl_examples %} +{%- block repair_block -%} +{% if repair %} +{{ repair -}} +{%- endif -%} +{% endblock repair_block %} + {%- block output_prefix -%} {% if output_prefix %} {{ output_prefix -}} {%- endif -%} -{% endblock output_prefix %} \ No newline at end of file +{% endblock output_prefix %} From f02e88b2e28d7c7cbf3b6a91bc920c488a68495a Mon Sep 17 00:00:00 2001 From: Hendrik Strobelt Date: Fri, 15 Aug 2025 12:51:45 +0200 Subject: [PATCH 2/5] new signature for RejectionSampling --- mellea/stdlib/sampling.py | 125 ++++++++++++++++++++++++-------------- mellea/stdlib/session.py | 10 +-- 2 files changed, 87 insertions(+), 48 deletions(-) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index c1dd7538..30464175 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -2,12 +2,13 @@ import abc from collections.abc import Callable +from copy import deepcopy from typing import Any import tqdm from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, GenerateLog, ModelOutputThunk +from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk from mellea.stdlib.instruction import 
Instruction from mellea.stdlib.requirement import Requirement, ValidationResult @@ -23,6 +24,7 @@ def __init__( sample_generations: list[ModelOutputThunk] | None = None, sample_validations: list[list[tuple[Requirement, ValidationResult]]] | None = None, + sample_actions: list[Component] | None = None, ): """Initialize a new instance of sampling results. @@ -47,56 +49,67 @@ class SamplingStrategy(abc.ABC): """ # the function signature here matches that of m.validate - validate: Callable[[list[Requirement], Any], list[ValidationResult]] | None = None + validate: ( + Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None + ) = None generate: ( - Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None ) = None @abc.abstractmethod def sample( self, - instruction: Instruction, + action: Component, + context: Context, *, generate_logs: list[GenerateLog] | None = None, + validation_ctx: Context | None = None, ) -> SamplingResult: """This method is the abstract method for sampling a given instruction. It must be implemented by any concrete subclasses to provide specific sampling logic. Args: - instruction (Instruction): The instruction object to be sampled. + action : The action object to be sampled. + context: The context to be passed to the sampling strategy. generate_logs: Optional list of GenerateLog objects. If None, no collection happens. + validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. 
""" class RejectionSamplingStrategy(SamplingStrategy): """Sampling strategy that rejects samples based on given instructions.""" + loop_budget: int + def __init__( self, *, loop_budget: int = 1, repair: Callable[ [ - Instruction, + Component, + Context, list[tuple[Requirement, ValidationResult]], - list[Instruction], + list[Component], ], - Instruction, - ] = lambda i, r, h_i: i, + Component, + ] = lambda i, c, r, h_i: i, select_from_failure: Callable[ [ - Instruction, + list[Component], list[ModelOutputThunk], list[list[tuple[Requirement, ValidationResult]]], ], - ModelOutputThunk, - ] = lambda _, results, __: results[0], - validate: Callable[[list[Requirement], Any], list[ValidationResult]] + int, + ] = lambda _, results, __: 0, + validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None = None, generate: ( - Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None ) = None, requirements: list[Requirement] | None = None, ): @@ -123,17 +136,23 @@ def __init__( def sample( self, - instruction: Instruction, + action: Component, + context: Context, *, show_progress: bool = True, generate_logs: list[GenerateLog] | None = None, + requirements: list[Requirement] | None = None, + validation_ctx: Context | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. Args: - instruction: The Instruction object containing the instruction to generate a valid model output thunk. - show_progress: if true, a tqdm progress bar is used. Otherwise messages will still be sent to flog. + action : The action object to be sampled. + context: The context to be passed to the sampling strategy. + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. generate_logs: If provided, the generations will be logged. + requirements: List of requirements to test against. 
+ validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. @@ -148,68 +167,86 @@ def sample( assert self.validate is not None, "Validation must be provided." assert self.generate is not None, "Generate must be provided." + # just to be sure to not cause issues to the OG context + ctx = context.copy() + validation_ctx = validation_ctx if validation_ctx is not None else context + assert validation_ctx is not None, "Validation context must be provided." + flog = FancyLogger.get_logger() - failed_results: list[ModelOutputThunk] = [] - failed_scores: list[list[tuple[Requirement, ValidationResult]]] = [] - failed_instructions: list[Instruction] = [] + sampled_results: list[ModelOutputThunk] = [] + sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] + sampled_actions: list[Component] = [] - loop_count = 0 + if self.requirements is not None: + reqs = self.requirements + if requirements is not None: + flog.warn("Some requirements are ignored.") + else: + reqs = requirements if requirements is not None else [] + loop_count = 0 loop_budget_range_iterator = ( - tqdm.tqdm(range(self.loop_budget)) + tqdm.tqdm(range(self.loop_budget)) # type: ignore if show_progress - else range(self.loop_budget) + else range(self.loop_budget) # type: ignore ) + + new_action = deepcopy(action) for _ in loop_budget_range_iterator: # type: ignore loop_count += 1 if not show_progress: flog.info(f"Running loop {loop_count} of {self.loop_budget}") - # run a pass + # run a generation pass + result = self.generate(new_action, ctx, generate_logs) - result = self.generate(instruction, generate_logs) + # validation pass + val_scores = self.validate(reqs, validation_ctx, result) - if self.requirements is not None: - reqs = self.requirements - else: - reqs = instruction.requirements - val_scores = self.validate(reqs, result) + # match up reqs with scores 
constraint_scores = list(zip(reqs, val_scores)) - failed_results.append(result) - failed_scores.append(constraint_scores) - failed_instructions.append(instruction) + # collect all data + sampled_results.append(result) + sampled_scores.append(constraint_scores) + sampled_actions.append(new_action) + # if all vals are true -- break and return success if all(bool(s[1]) for s in constraint_scores): flog.info("SUCCESS") return SamplingResult( result, success=True, - sample_generations=failed_results, - sample_validations=failed_scores, + sample_generations=sampled_results, + sample_validations=sampled_scores, ) else: + # log partial success and continue count_valid = len([s for s in constraint_scores if bool(s[1])]) flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}") + # If we did not pass all constraints, update the instruction and try again. - instruction = self.repair( - instruction, constraint_scores, failed_instructions + new_action = self.repair( + new_action, ctx, constraint_scores, sampled_actions ) flog.info( - f"Invoking select_from_failure after {len(failed_results)} failed attempts." + f"Invoking select_from_failure after {len(sampled_results)} failed attempts." ) - best_failed_result = self.select_from_failure( - instruction, failed_results, failed_scores + + # if no valid result could be determined, find a last resort. + best_failed_index = self.select_from_failure( + sampled_actions, sampled_results, sampled_scores ) - assert best_failed_result in failed_results, ( + assert best_failed_index < len(sampled_results), ( "The select_from_failure method did not return a valid result. It has to selected from failed_results." 
) return SamplingResult( - best_failed_result, + sampled_results[best_failed_index], success=False, - sample_generations=failed_results, - sample_validations=failed_scores, + sample_generations=sampled_results, + sample_validations=sampled_scores, + sample_actions=sampled_actions, ) diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index cc0d3378..280ade4e 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -199,15 +199,17 @@ def instruct( generate_logs[0].is_final_result = True else: if strategy.validate is None: - strategy.validate = lambda reqs, output: self.validate( # type: ignore + strategy.validate = lambda reqs, val_ctx, output: self.validate( # type: ignore reqs, output=output, # type: ignore ) # type: ignore if strategy.generate is None: strategy.generate = ( - lambda instruction, g_logs: self.backend.generate_from_context( + lambda instruction, + gen_ctx, + g_logs: self.backend.generate_from_context( instruction, - ctx=self.ctx, + ctx=gen_ctx, format=format, model_options=model_options, generate_logs=g_logs, @@ -216,7 +218,7 @@ def instruct( ) # sample - res = strategy.sample(i, generate_logs=generate_logs) + res = strategy.sample(i, self.ctx, generate_logs=generate_logs) # make sure that one Log is marked as the one related to res.result if res.success: From c9de1f9d0dc5e17da964e65ec5cb50d8cd23b531 Mon Sep 17 00:00:00 2001 From: Hendrik Strobelt Date: Fri, 15 Aug 2025 13:32:15 +0200 Subject: [PATCH 3/5] adding RejectionSampling and AgenticSampling as subclasses of BaseSampling --- mellea/stdlib/sampling.py | 147 +++++++++++++++++++++++++++++++++++--- 1 file changed, 139 insertions(+), 8 deletions(-) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 30464175..fc3653f4 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -7,8 +7,17 @@ import tqdm +from mellea import LinearContext from mellea.helpers.fancy_logger import FancyLogger -from mellea.stdlib.base import CBlock, 
Component, Context, GenerateLog, ModelOutputThunk +from mellea.stdlib.base import ( + CBlock, + Component, + Context, + ContextTurn, + GenerateLog, + ModelOutputThunk, +) +from mellea.stdlib.chat import Message from mellea.stdlib.instruction import Instruction from mellea.stdlib.requirement import Requirement, ValidationResult @@ -79,8 +88,8 @@ def sample( """ -class RejectionSamplingStrategy(SamplingStrategy): - """Sampling strategy that rejects samples based on given instructions.""" +class BaseSamplingStrategy(SamplingStrategy): + """Base class for multiple strategies that rejects samples based on given instructions.""" loop_budget: int @@ -90,13 +99,14 @@ def __init__( loop_budget: int = 1, repair: Callable[ [ - Component, Context, - list[tuple[Requirement, ValidationResult]], list[Component], + list[ModelOutputThunk], + list[list[tuple[Requirement, ValidationResult]]], ], Component, - ] = lambda i, c, r, h_i: i, + ] + | None, select_from_failure: Callable[ [ list[Component], @@ -104,7 +114,8 @@ def __init__( list[list[tuple[Requirement, ValidationResult]]], ], int, - ] = lambda _, results, __: 0, + ] + | None, validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None = None, generate: ( @@ -127,6 +138,9 @@ def __init__( AssertionError: If loop_budget is not greater than 0. """ assert loop_budget > 0, "Loop budget must be at least 1." + assert repair is not None, "Repair must be provided." + assert select_from_failure is not None, "Select from failure must be provided." + self.loop_budget = loop_budget self.repair = repair self.select_from_failure = select_from_failure @@ -229,7 +243,7 @@ def sample( # If we did not pass all constraints, update the instruction and try again. 
new_action = self.repair( - new_action, ctx, constraint_scores, sampled_actions + ctx, sampled_actions, sampled_results, sampled_scores ) flog.info( @@ -250,3 +264,120 @@ def sample( sample_validations=sampled_scores, sample_actions=sampled_actions, ) + + +class RejectionSamplingStrategy(BaseSamplingStrategy): + """Simple rejection sampling strategy with optional repair.""" + + def __init__( + self, + *, + loop_budget: int = 1, + repair: Callable[ + [ + list[Component], + list[ModelOutputThunk], + list[list[tuple[Requirement, ValidationResult]]], + ], + Component, + ] = lambda past_actions, past_results, past_val: past_actions[-1], + select_from_failure: Callable[ + [ + list[Component], + list[ModelOutputThunk], + list[list[tuple[Requirement, ValidationResult]]], + ], + int, + ] = lambda past_actions, past_results, past_val: 0, + validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] + | None = None, + generate: ( + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None + ) = None, + requirements: list[Requirement] | None = None, + ): + def repair_wrapper(_, past_actions, past_results, past_val): + return repair(past_actions, past_results, past_val) + + super().__init__( + loop_budget=loop_budget, + repair=repair_wrapper, + select_from_failure=select_from_failure, + validate=validate, + generate=generate, + requirements=requirements, + ) + + +class AgenticSamplingStrategy(BaseSamplingStrategy): + """Rejection sampling strategy with agentic (multi-turn) repair.""" + + def __init__( + self, + *, + loop_budget: int = 1, + repair: Callable[ + [ + Context, + list[Component], + list[ModelOutputThunk], + list[list[tuple[Requirement, ValidationResult]]], + ], + Component, + ] + | None = None, + select_from_failure: Callable[ + [ + list[Component], + list[ModelOutputThunk], + list[list[tuple[Requirement, ValidationResult]]], + ], + int, + ] = lambda past_actions, past_results, past_val: len(past_actions) - 1, + validate: 
Callable[[list[Requirement], Context, Any], list[ValidationResult]] + | None = None, + generate: ( + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None + ) = None, + requirements: list[Requirement] | None = None, + ): + if repair is None: + repair = AgenticSamplingStrategy.agentic_repair_default + + super().__init__( + loop_budget=loop_budget, + repair=repair, + select_from_failure=select_from_failure, + validate=validate, + generate=generate, + requirements=requirements, + ) + + @staticmethod + def agentic_repair_default( + context: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> Component: + assert isinstance(context, LinearContext), ( + " Need linear context to run agentic sampling." + ) + + # add failed execution to chat history + context.insert_turn(ContextTurn(past_actions[-1], past_results[-1])) + + last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]] + last_failed_reqs_str = "* " + "\n* ".join( + [str(r.description) for r in last_failed_reqs] + ) + # TODO: what to do with checks ?? 
+ + next_action = Message( + role="user", + content=f"The following requirements have not been met: \n{last_failed_reqs_str}\n Please try again to fulfill the requirements.", + ) + + return next_action From decab140c9b92c6881ed381501470c7f7a0c2504 Mon Sep 17 00:00:00 2001 From: Hendrik Strobelt Date: Thu, 28 Aug 2025 16:52:11 +0200 Subject: [PATCH 4/5] fixing requirements and adding tests --- mellea/stdlib/sampling.py | 18 ++++--- mellea/stdlib/session.py | 4 +- test/stdlib_basics/test_sampling_ctx.py | 63 +++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 8 deletions(-) create mode 100644 test/stdlib_basics/test_sampling_ctx.py diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index fc3653f4..8df9ffbe 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -72,6 +72,7 @@ def sample( self, action: Component, context: Context, + requirements: list[Requirement], *, generate_logs: list[GenerateLog] | None = None, validation_ctx: Context | None = None, @@ -83,6 +84,7 @@ def sample( Args: action : The action object to be sampled. context: The context to be passed to the sampling strategy. + requirements: The requirements to be used by the sampling strategy (merged with global requirements). generate_logs: Optional list of GenerateLog objects. If None, no collection happens. validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. """ @@ -152,10 +154,10 @@ def sample( self, action: Component, context: Context, + requirements: list[Requirement], *, show_progress: bool = True, generate_logs: list[GenerateLog] | None = None, - requirements: list[Requirement] | None = None, validation_ctx: Context | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. @@ -165,7 +167,7 @@ def sample( context: The context to be passed to the sampling strategy. show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. 
generate_logs: If provided, the generations will be logged. - requirements: List of requirements to test against. + requirements: List of requirements to test against (merged with global requirements). validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. Returns: @@ -192,12 +194,14 @@ def sample( sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = [] sampled_actions: list[Component] = [] + reqs = [] + # global requirements supersede local requirements (global requirements can be defined by user) + # Todo: re-evaluate if this makes sense if self.requirements is not None: - reqs = self.requirements - if requirements is not None: - flog.warn("Some requirements are ignored.") - else: - reqs = requirements if requirements is not None else [] + reqs += self.requirements + elif requirements is not None: + reqs += requirements + reqs = list(set(reqs)) loop_count = 0 loop_budget_range_iterator = ( diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 280ade4e..fcdd2131 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -218,7 +218,9 @@ def instruct( ) # sample - res = strategy.sample(i, self.ctx, generate_logs=generate_logs) + res = strategy.sample( + i, self.ctx, i.requirements, generate_logs=generate_logs + ) # make sure that one Log is marked as the one related to res.result if res.success: diff --git a/test/stdlib_basics/test_sampling_ctx.py b/test/stdlib_basics/test_sampling_ctx.py new file mode 100644 index 00000000..b6a5e9bb --- /dev/null +++ b/test/stdlib_basics/test_sampling_ctx.py @@ -0,0 +1,63 @@ +from mellea import LinearContext, start_session +from mellea.backends import ModelOption +from mellea.stdlib.sampling import ( + AgenticSamplingStrategy, + RejectionSamplingStrategy, + SamplingResult, +) + + +class TestSamplingCtxCase: + m = start_session( + model_options={ModelOption.MAX_NEW_TOKENS: 100}, ctx=LinearContext() + ) + + def _run_asserts_for_ctx_testing(self, res): + 
assert isinstance(res, SamplingResult), "res should be a SamplingResult." + + assert isinstance(res.value, str), "Value should be set and a string." + + assert len(res.sample_generations) >= 1, ( + "sample generation should have at least one sample." + ) + assert len(res.sample_validations) >= 1, ( + "sample validation should have at least one sample." + ) + assert len(res.sample_validations[0]) == 3, ( + "there should be 3 validation results." + ) + assert len(self.m.ctx._ctx) == 2, ( + "there should only be a message and a response in the ctx." + ) + + def test_ctx_for_rejection_sampling(self): + self.m.ctx.reset() + res = self.m.instruct( + "Write a sentence.", + requirements=[ + "be funny", + "be formal", + "use only words starting with the letter w", + ], + strategy=RejectionSamplingStrategy(loop_budget=3), + return_sampling_results=True, + ) + self._run_asserts_for_ctx_testing(res) + assert len(self.m.last_prompt()) == 1, "Last prompt should only have only one instruction inside - independent of sampling iterations." + + def test_ctx_for_agentic(self): + self.m.ctx.reset() + res = self.m.instruct( + "Write a sentence.", + requirements=[ + "be funny", + "be formal", + "use only words starting with the letter w", + ], + strategy=AgenticSamplingStrategy(loop_budget=3), + return_sampling_results=True, + ) + + self._run_asserts_for_ctx_testing(res) + + assert len(self.m.last_prompt()) == len(res.sample_generations)*2-1, "For n sampling iterations there should be 2n-1 prompt conversation elements in the last prompt." From 570dd7fdbb1c678e10bc69a25c05442369e64b93 Mon Sep 17 00:00:00 2001 From: Hendrik Strobelt Date: Thu, 28 Aug 2025 17:40:26 +0200 Subject: [PATCH 5/5] refactoring repair and select-from-failure as abstract methods. 
--- mellea/stdlib/sampling.py | 215 ++++++++++++------------ test/stdlib_basics/test_sampling_ctx.py | 6 +- 2 files changed, 106 insertions(+), 115 deletions(-) diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py index 8df9ffbe..8f72c1a0 100644 --- a/mellea/stdlib/sampling.py +++ b/mellea/stdlib/sampling.py @@ -99,25 +99,6 @@ def __init__( self, *, loop_budget: int = 1, - repair: Callable[ - [ - Context, - list[Component], - list[ModelOutputThunk], - list[list[tuple[Requirement, ValidationResult]]], - ], - Component, - ] - | None, - select_from_failure: Callable[ - [ - list[Component], - list[ModelOutputThunk], - list[list[tuple[Requirement, ValidationResult]]], - ], - int, - ] - | None, validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None = None, generate: ( @@ -130,8 +111,6 @@ def __init__( Args: loop_budget: Number of times to iterate through the process. Must be greater than 0. - repair: Function to apply "repairs" to an instruction based on its requirements and validation results. - select_from_failure: Function to select a model output thunk from failed attempts. validate: Function to validate the results against requirements. If None, validation is provided later through setter. generate: Function to generate new model output thunks. If None, generate is provided later through setter. requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. @@ -140,16 +119,53 @@ def __init__( AssertionError: If loop_budget is not greater than 0. """ assert loop_budget > 0, "Loop budget must be at least 1." - assert repair is not None, "Repair must be provided." - assert select_from_failure is not None, "Select from failure must be provided." 
self.loop_budget = loop_budget - self.repair = repair - self.select_from_failure = select_from_failure self.validate = validate # it's ok to be None here self.generate = generate # it's ok to be None here self.requirements = requirements + @staticmethod + @abc.abstractmethod + def repair( + ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> Component: + """ + Repair function that is being invoked if not all requirements are fulfilled. It should return a next action component. + + Args: + ctx: The context to be passed to the sampling strategy. + past_actions: List of actions that have been executed (without success). + past_results: List of (unsuccessful) generation results for these actions. + past_val: List of validation results for the results. + + Returns: + The next action component. + """ + ... + + @staticmethod + @abc.abstractmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ): + """This function returns the index of the result that should be selected as `.value` iff the loop budget is exhausted and no success. + + Args: + sampled_actions: List of actions that have been executed (without success). + sampled_results: List of (unsuccessful) generation results for these actions. + sampled_val: List of validation results for the results. + + Returns: + The index of the result that should be selected as `.value`. + """ + ... + def sample( self, action: Component, @@ -176,10 +192,6 @@ def sample( Raises: AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling. """ - assert self.repair is not None, "Repair must be provided." - assert self.select_from_failure is not None, ( - "Select from failure must be provided." 
- ) assert self.validate is not None, "Validation must be provided." assert self.generate is not None, "Generate must be provided." @@ -271,96 +283,75 @@ def sample( class RejectionSamplingStrategy(BaseSamplingStrategy): - """Simple rejection sampling strategy with optional repair.""" + """Simple rejection sampling strategy that just repeats the same call on failure.""" - def __init__( - self, - *, - loop_budget: int = 1, - repair: Callable[ - [ - list[Component], - list[ModelOutputThunk], - list[list[tuple[Requirement, ValidationResult]]], - ], - Component, - ] = lambda past_actions, past_results, past_val: past_actions[-1], - select_from_failure: Callable[ - [ - list[Component], - list[ModelOutputThunk], - list[list[tuple[Requirement, ValidationResult]]], - ], - int, - ] = lambda past_actions, past_results, past_val: 0, - validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] - | None = None, - generate: ( - Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] - | None - ) = None, - requirements: list[Requirement] | None = None, - ): - def repair_wrapper(_, past_actions, past_results, past_val): - return repair(past_actions, past_results, past_val) - - super().__init__( - loop_budget=loop_budget, - repair=repair_wrapper, - select_from_failure=select_from_failure, - validate=validate, - generate=generate, - requirements=requirements, - ) + @staticmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> int: + # simply returns the first attempt if all loops fail + return 0 + @staticmethod + def repair( + ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> Component: + # repeat the last action again. 
+ return past_actions[-1] -class AgenticSamplingStrategy(BaseSamplingStrategy): - """Rejection sampling strategy with agentic (multi-turn) repair.""" - def __init__( - self, - *, - loop_budget: int = 1, - repair: Callable[ - [ - Context, - list[Component], - list[ModelOutputThunk], - list[list[tuple[Requirement, ValidationResult]]], - ], - Component, - ] - | None = None, - select_from_failure: Callable[ - [ - list[Component], - list[ModelOutputThunk], - list[list[tuple[Requirement, ValidationResult]]], - ], - int, - ] = lambda past_actions, past_results, past_val: len(past_actions) - 1, - validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] - | None = None, - generate: ( - Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] - | None - ) = None, - requirements: list[Requirement] | None = None, +class RepairTemplateStrategy(BaseSamplingStrategy): + """A sampling strategy that adds a repair string to the instruction object.""" + + @staticmethod + def select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> int: + # simply returns the first attempt if all loops fail + return 0 + + @staticmethod + def repair( + ctx: Context, + past_actions: list[Component], + past_results: list[ModelOutputThunk], + past_val: list[list[tuple[Requirement, ValidationResult]]], + ) -> Component: + pa = past_actions[-1] + if isinstance(pa, Instruction): + last_failed_reqs: list[Requirement] = [ + s[0] for s in past_val[-1] if not s[1] + ] + last_failed_reqs_str = "* " + "\n* ".join( + [str(r.description) for r in last_failed_reqs] + ) + return pa.copy_and_repair( + repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}" + ) + return past_actions[-1] + + +class MultiTurnStrategy(BaseSamplingStrategy): + """Rejection sampling strategy with (agentic) multi-turn repair.""" + + @staticmethod + def 
select_from_failure( + sampled_actions: list[Component], + sampled_results: list[ModelOutputThunk], + sampled_val: list[list[tuple[Requirement, ValidationResult]]], ): - if repair is None: - repair = AgenticSamplingStrategy.agentic_repair_default - - super().__init__( - loop_budget=loop_budget, - repair=repair, - select_from_failure=select_from_failure, - validate=validate, - generate=generate, - requirements=requirements, - ) + # return the last assistant message even if all attempts of repair failed. + return -1 @staticmethod - def agentic_repair_default( + def repair( context: Context, past_actions: list[Component], past_results: list[ModelOutputThunk], diff --git a/test/stdlib_basics/test_sampling_ctx.py b/test/stdlib_basics/test_sampling_ctx.py index b6a5e9bb..f178aae4 100644 --- a/test/stdlib_basics/test_sampling_ctx.py +++ b/test/stdlib_basics/test_sampling_ctx.py @@ -1,7 +1,7 @@ from mellea import LinearContext, start_session from mellea.backends import ModelOption from mellea.stdlib.sampling import ( - AgenticSamplingStrategy, + MultiTurnStrategy, RejectionSamplingStrategy, SamplingResult, ) @@ -45,7 +45,7 @@ def test_ctx_for_rejection_sampling(self): self._run_asserts_for_ctx_testing(res) assert len(self.m.last_prompt()) == 1, "Last prompt should only have only one instruction inside - independent of sampling iterations." - def test_ctx_for_agentic(self): + def test_ctx_for_multiturn(self): self.m.ctx.reset() res = self.m.instruct( "Write a sentence.", @@ -54,7 +54,7 @@ def test_ctx_for_agentic(self): "be formal", "use only words starting with the letter w", ], - strategy=AgenticSamplingStrategy(loop_budget=3), + strategy=MultiTurnStrategy(loop_budget=3), return_sampling_results=True, )