diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/instruction.py
index cca98930..e7b5f885 100644
--- a/mellea/stdlib/instruction.py
+++ b/mellea/stdlib/instruction.py
@@ -1,5 +1,7 @@
 """Instructions."""
 
+from __future__ import annotations
+
 from copy import deepcopy
 
 import jinja2
@@ -106,6 +108,7 @@ def __init__(
         self._output_prefix = (
             blockify(output_prefix) if output_prefix is not None else None
         )
+        self._repair_string: str | None = None
 
     def parts(self):
         """Returns all of the constituent parts of an Instruction."""
@@ -132,6 +135,7 @@ def format_for_llm(self) -> TemplateRepresentation:
                 "output_prefix": (
                     self._output_prefix if self._output_prefix is not None else None
                 ),
+                "repair": self._repair_string,
             },
             tools=None,
             template_order=["*", "Instruction"],
@@ -147,3 +151,9 @@ def apply_user_dict_from_jinja(user_dict: dict[str, str], s: str) -> str:
     def requirements(self) -> list[Requirement]:
         """Returns a list of Requirement instances."""
         return self._requirements
+
+    def copy_and_repair(self, repair_string: str) -> Instruction:
+        """Creates a copy of the instruction and adds/overwrites the repair string."""
+        res = deepcopy(self)
+        res._repair_string = repair_string
+        return res
diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py
index c5ef33ce..651d77a2 100644
--- a/mellea/stdlib/sampling.py
+++ b/mellea/stdlib/sampling.py
@@ -2,12 +2,22 @@
 
 import abc
 from collections.abc import Callable
+from copy import deepcopy
 from typing import Any
 
 import tqdm
 
+from mellea import LinearContext
 from mellea.helpers.fancy_logger import FancyLogger
-from mellea.stdlib.base import CBlock, GenerateLog, ModelOutputThunk
+from mellea.stdlib.base import (
+    CBlock,
+    Component,
+    Context,
+    ContextTurn,
+    GenerateLog,
+    ModelOutputThunk,
+)
+from mellea.stdlib.chat import Message
 from mellea.stdlib.instruction import Instruction
 from mellea.stdlib.requirement import Requirement, ValidationResult
 
@@ -23,6 +33,7 @@ def __init__(
         sample_generations: list[ModelOutputThunk] | None = None,
         sample_validations: list[list[tuple[Requirement, ValidationResult]]]
         | None = None,
+        sample_actions: list[Component] | None = None,
     ):
         """Initialize a new instance of sampling results.
 
@@ -47,56 +58,52 @@ class SamplingStrategy(abc.ABC):
     """
 
     # the function signature here matches that of m.validate
-    validate: Callable[[list[Requirement], Any], list[ValidationResult]] | None = None
+    validate: (
+        Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None
+    ) = None
     generate: (
-        Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None
+        Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
+        | None
    ) = None
 
    @abc.abstractmethod
    def sample(
        self,
-        instruction: Instruction,
+        action: Component,
+        context: Context,
+        requirements: list[Requirement],
        *,
        generate_logs: list[GenerateLog] | None = None,
+        validation_ctx: Context | None = None,
    ) -> SamplingResult:
-        """This method is the abstract method for sampling a given instruction.
+        """This method is the abstract method for sampling a given action.
 
        It must be implemented by any concrete subclasses to provide specific sampling logic.
 
        Args:
-            instruction (Instruction): The instruction object to be sampled.
+            action: The action object to be sampled.
+            context: The context to be passed to the sampling strategy.
+            requirements: The requirements to be used by the sampling strategy (merged with global requirements).
            generate_logs: Optional list of GenerateLog objects. If None, no collection happens.
+            validation_ctx: Optional context to use for validation.
+                If None, validation_ctx = ctx.
        """
 
 
-class RejectionSamplingStrategy(SamplingStrategy):
-    """Sampling strategy that rejects samples based on given instructions."""
+class BaseSamplingStrategy(SamplingStrategy):
+    """Base class for sampling strategies that reject samples based on given requirements."""
+
+    loop_budget: int
 
    def __init__(
        self,
        *,
        loop_budget: int = 1,
-        repair: Callable[
-            [
-                Instruction,
-                list[tuple[Requirement, ValidationResult]],
-                list[Instruction],
-            ],
-            Instruction,
-        ] = lambda i, r, h_i: i,
-        select_from_failure: Callable[
-            [
-                Instruction,
-                list[ModelOutputThunk],
-                list[list[tuple[Requirement, ValidationResult]]],
-            ],
-            ModelOutputThunk,
-        ] = lambda _, results, __: results[0],
-        validate: Callable[[list[Requirement], Any], list[ValidationResult]]
+        validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
        | None = None,
        generate: (
-            Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None
+            Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
+            | None
        ) = None,
        requirements: list[Requirement] | None = None,
    ):
@@ -104,8 +111,6 @@ def __init__(
 
        Args:
            loop_budget: Number of times to iterate through the process. Must be greater than 0.
-            repair: Function to apply "repairs" to an instruction based on its requirements and validation results.
-            select_from_failure: Function to select a model output thunk from failed attempts.
            validate: Function to validate the results against requirements. If None, validation is provided later through setter.
            generate: Function to generate new model output thunks. If None, generate is provided later through setter.
            requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
@@ -114,26 +119,72 @@ def __init__(
            AssertionError: If loop_budget is not greater than 0.
        """
        assert loop_budget > 0, "Loop budget must be at least 1."
+
        self.loop_budget = loop_budget
-        self.repair = repair
-        self.select_from_failure = select_from_failure
        self.validate = validate  # it's ok to be None here
        self.generate = generate  # it's ok to be None here
        self.requirements = requirements
 
+    @staticmethod
+    @abc.abstractmethod
+    def repair(
+        ctx: Context,
+        past_actions: list[Component],
+        past_results: list[ModelOutputThunk],
+        past_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> Component:
+        """Repair function invoked when not all requirements are fulfilled. It should return the next action component.
+
+        Args:
+            ctx: The context to be passed to the sampling strategy.
+            past_actions: List of actions that have been executed (without success).
+            past_results: List of (unsuccessful) generation results for these actions.
+            past_val: List of validation results for the results.
+
+        Returns:
+            The next action component.
+        """
+        ...
+
+    @staticmethod
+    @abc.abstractmethod
+    def select_from_failure(
+        sampled_actions: list[Component],
+        sampled_results: list[ModelOutputThunk],
+        sampled_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> int:
+        """Returns the index of the result that should be selected as `.value` if the loop budget is exhausted without success.
+
+        Args:
+            sampled_actions: List of actions that have been executed (without success).
+            sampled_results: List of (unsuccessful) generation results for these actions.
+            sampled_val: List of validation results for the results.
+
+        Returns:
+            The index of the result that should be selected as `.value`.
+        """
+        ...
+
    def sample(
        self,
-        instruction: Instruction,
+        action: Component,
+        context: Context,
+        requirements: list[Requirement],
        *,
        show_progress: bool = True,
        generate_logs: list[GenerateLog] | None = None,
+        validation_ctx: Context | None = None,
    ) -> SamplingResult:
-        """This method performs a sampling operation based on the given instruction.
+        """This method performs a sampling operation based on the given action.
 
        Args:
-            instruction: The Instruction object containing the instruction to generate a valid model output thunk.
-            show_progress: if true, a tqdm progress bar is used. Otherwise messages will still be sent to flog.
+            action: The action object to be sampled.
+            context: The context to be passed to the sampling strategy.
+            requirements: List of requirements to test against (merged with global requirements).
+            show_progress: If true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
            generate_logs: If provided, the generations will be logged.
+            validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx.
 
        Returns:
            SamplingResult: A result object indicating the success or failure of the sampling process.
 
        Raises:
-            AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling.
+            AssertionError: Asserts that the required components (validate and generate) are provided before proceeding with the sampling.
        """
-        assert self.repair is not None, "Repair must be provided."
-        assert self.select_from_failure is not None, (
-            "Select from failure must be provided."
-        )
        assert self.validate is not None, "Validation must be provided."
        assert self.generate is not None, "Generate must be provided."
 
+        # work on a copy so that we do not mutate the original context
+        ctx = context.copy()
+        validation_ctx = validation_ctx if validation_ctx is not None else context
+        assert validation_ctx is not None, "Validation context must be provided."
+
        flog = FancyLogger.get_logger()
 
+        sampled_results: list[ModelOutputThunk] = []
+        sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
+        sampled_actions: list[Component] = []
+
        # The `logging_redirect_tqdm` approach did not work, so instead we will use the show_progress
        # flag to determine whether we should show the pbar.
        show_progress = show_progress and flog.getEffectiveLevel() <= FancyLogger.INFO
 
-        failed_results: list[ModelOutputThunk] = []
-        failed_scores: list[list[tuple[Requirement, ValidationResult]]] = []
-        failed_instructions: list[Instruction] = []
+        reqs = []
+        # global requirements supersede local requirements (global requirements can be defined by the user)
+        # TODO: re-evaluate if this makes sense
+        if self.requirements is not None:
+            reqs += self.requirements
+        elif requirements is not None:
+            reqs += requirements
+        reqs = list(set(reqs))
 
        loop_count = 0
        loop_budget_range_iterator = (
-            tqdm.tqdm(range(self.loop_budget))
+            tqdm.tqdm(range(self.loop_budget))  # type: ignore
            if show_progress
-            else range(self.loop_budget)
+            else range(self.loop_budget)  # type: ignore
        )
+
+        new_action = deepcopy(action)
        for _ in loop_budget_range_iterator:  # type: ignore
            loop_count += 1
            if not show_progress:
                flog.info(f"Running loop {loop_count} of {self.loop_budget}")
 
-            # run a pass
-            result = self.generate(instruction, generate_logs)
+            # run a generation pass
+            result = self.generate(new_action, ctx, generate_logs)
 
-            if self.requirements is not None:
-                reqs = self.requirements
-            else:
-                reqs = instruction.requirements
-            val_scores = self.validate(reqs, result)
+            # validation pass
+            val_scores = self.validate(reqs, validation_ctx, result)
 
+            # match up reqs with scores
            constraint_scores = list(zip(reqs, val_scores))
 
-            failed_results.append(result)
-            failed_scores.append(constraint_scores)
-            failed_instructions.append(instruction)
+            # collect all data
+            sampled_results.append(result)
+            sampled_scores.append(constraint_scores)
+            sampled_actions.append(new_action)
 
+            # if all validations pass, break and return success
            if all(bool(s[1]) for s in constraint_scores):
                flog.info("SUCCESS")
                return SamplingResult(
                    result,
                    success=True,
-                    sample_generations=failed_results,
-                    sample_validations=failed_scores,
+                    sample_generations=sampled_results,
+                    sample_validations=sampled_scores,
                )
            else:
+                # log partial success and continue
                count_valid = len([s for s in constraint_scores if bool(s[1])])
                flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")
 
            # If we did not pass all constraints, update the instruction and try again.
-            instruction = self.repair(
-                instruction, constraint_scores, failed_instructions
+            new_action = self.repair(
+                ctx, sampled_actions, sampled_results, sampled_scores
            )
 
        flog.info(
-            f"Invoking select_from_failure after {len(failed_results)} failed attempts."
+            f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
        )
 
-        best_failed_result = self.select_from_failure(
-            instruction, failed_results, failed_scores
+        # if no valid result could be determined, find a last resort.
+        best_failed_index = self.select_from_failure(
+            sampled_actions, sampled_results, sampled_scores
        )
-        assert best_failed_result in failed_results, (
-            "The select_from_failure method did not return a valid result. It has to selected from failed_results."
+        assert best_failed_index < len(sampled_results), (
+            "The select_from_failure method did not return a valid index. It has to select from sampled_results."
        )
 
        return SamplingResult(
-            best_failed_result,
+            sampled_results[best_failed_index],
            success=False,
-            sample_generations=failed_results,
-            sample_validations=failed_scores,
+            sample_generations=sampled_results,
+            sample_validations=sampled_scores,
+            sample_actions=sampled_actions,
        )
+
+
+class RejectionSamplingStrategy(BaseSamplingStrategy):
+    """Simple rejection sampling strategy that just repeats the same call on failure."""
+
+    @staticmethod
+    def select_from_failure(
+        sampled_actions: list[Component],
+        sampled_results: list[ModelOutputThunk],
+        sampled_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> int:
+        # simply return the first attempt if all loops fail
+        return 0
+
+    @staticmethod
+    def repair(
+        ctx: Context,
+        past_actions: list[Component],
+        past_results: list[ModelOutputThunk],
+        past_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> Component:
+        # repeat the last action unchanged
+        return past_actions[-1]
+
+
+class RepairTemplateStrategy(BaseSamplingStrategy):
+    """A sampling strategy that adds a repair string to the instruction object."""
+
+    @staticmethod
+    def select_from_failure(
+        sampled_actions: list[Component],
+        sampled_results: list[ModelOutputThunk],
+        sampled_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> int:
+        # simply return the first attempt if all loops fail
+        return 0
+
+    @staticmethod
+    def repair(
+        ctx: Context,
+        past_actions: list[Component],
+        past_results: list[ModelOutputThunk],
+        past_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> Component:
+        pa = past_actions[-1]
+        if isinstance(pa, Instruction):
+            last_failed_reqs: list[Requirement] = [
+                s[0] for s in past_val[-1] if not s[1]
+            ]
+            last_failed_reqs_str = "* " + "\n* ".join(
+                [str(r.description) for r in last_failed_reqs]
+            )
+            return pa.copy_and_repair(
+                repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}"
+            )
+        return past_actions[-1]
+
+
+class MultiTurnStrategy(BaseSamplingStrategy):
+    """Rejection sampling strategy with (agentic) multi-turn repair."""
+
+    @staticmethod
+    def select_from_failure(
+        sampled_actions: list[Component],
+        sampled_results: list[ModelOutputThunk],
+        sampled_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> int:
+        # return the last assistant message even if all repair attempts failed
+        return -1
+
+    @staticmethod
+    def repair(
+        context: Context,
+        past_actions: list[Component],
+        past_results: list[ModelOutputThunk],
+        past_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> Component:
+        assert isinstance(context, LinearContext), (
+            "Need a linear context to run agentic sampling."
+        )
+
+        # add the failed execution to the chat history
+        context.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
+
+        last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]]
+        last_failed_reqs_str = "* " + "\n* ".join(
+            [str(r.description) for r in last_failed_reqs]
+        )
+        # TODO: what to do with checks?
+
+        next_action = Message(
+            role="user",
+            content=f"The following requirements have not been met:\n{last_failed_reqs_str}\nPlease try again to fulfill the requirements.",
+        )
+
+        return next_action
diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py
index 0250b768..e3b08d1b 100644
--- a/mellea/stdlib/session.py
+++ b/mellea/stdlib/session.py
@@ -293,15 +293,17 @@ def instruct(
                generate_logs[0].is_final_result = True
        else:
            if strategy.validate is None:
-                strategy.validate = lambda reqs, output: self.validate(  # type: ignore
+                strategy.validate = lambda reqs, val_ctx, output: self.validate(  # type: ignore
                    reqs,
                    output=output,  # type: ignore
                )  # type: ignore
            if strategy.generate is None:
                strategy.generate = (
-                    lambda instruction, g_logs: self.backend.generate_from_context(
+                    lambda instruction,
+                    gen_ctx,
+                    g_logs: self.backend.generate_from_context(
                        instruction,
-                        ctx=self.ctx,
+                        ctx=gen_ctx,
                        format=format,
                        model_options=model_options,
                        generate_logs=g_logs,
                    )
                )
 
            # sample
-            res = strategy.sample(i, generate_logs=generate_logs)
+            res = strategy.sample(
+                i, self.ctx, i.requirements, generate_logs=generate_logs
+            )
 
            # make sure that one Log is marked as the one related to res.result
            if res.success:
diff --git a/mellea/templates/prompts/default/Instruction.jinja2 b/mellea/templates/prompts/default/Instruction.jinja2
index c576b806..52d0766e 100644
--- a/mellea/templates/prompts/default/Instruction.jinja2
+++ b/mellea/templates/prompts/default/Instruction.jinja2
@@ -37,8 +37,14 @@ Here is some grounding context:
 {%- endif -%}
 {% endblock grounding_context %}
 
+{%- block repair_block -%}
+{% if repair %}
+{{ repair -}}
+{%- endif -%}
+{% endblock repair_block %}
+
 {%- block output_prefix -%}
 {% if output_prefix %}
 {{ output_prefix -}}
 {%- endif -%}
-{% endblock output_prefix %}
\ No newline at end of file
+{% endblock output_prefix %}
diff --git a/mellea/templates/prompts/granite/Instruction.jinja2 b/mellea/templates/prompts/granite/Instruction.jinja2
index ea9851e9..a2c79d6a 100644
--- a/mellea/templates/prompts/granite/Instruction.jinja2
+++ b/mellea/templates/prompts/granite/Instruction.jinja2
@@ -37,8 +37,14 @@ Here are some examples of what the response might look like:
 {%- endif -%}
 {% endblock icl_examples %}
 
+{%- block repair_block -%}
+{% if repair %}
+{{ repair -}}
+{%- endif -%}
+{% endblock repair_block %}
+
 {%- block output_prefix -%}
 {% if output_prefix %}
 {{ output_prefix -}}
 {%- endif -%}
-{% endblock output_prefix %}
\ No newline at end of file
+{% endblock output_prefix %}
diff --git a/test/stdlib_basics/test_sampling_ctx.py b/test/stdlib_basics/test_sampling_ctx.py
new file mode 100644
index 00000000..f178aae4
--- /dev/null
+++ b/test/stdlib_basics/test_sampling_ctx.py
@@ -0,0 +1,63 @@
+from mellea import LinearContext, start_session
+from mellea.backends import ModelOption
+from mellea.stdlib.sampling import (
+    MultiTurnStrategy,
+    RejectionSamplingStrategy,
+    SamplingResult,
+)
+
+
+class TestSamplingCtxCase:
+    m = start_session(
+        model_options={ModelOption.MAX_NEW_TOKENS: 100}, ctx=LinearContext()
+    )
+
+    def _run_asserts_for_ctx_testing(self, res):
+        assert isinstance(res, SamplingResult), "res should be a SamplingResult."
+
+        assert isinstance(res.value, str), "Value should be set and a string."
+
+        assert len(res.sample_generations) >= 1, (
+            "sample generation should have at least one sample."
+        )
+        assert len(res.sample_validations) >= 1, (
+            "sample validation should have at least one sample."
+        )
+        assert len(res.sample_validations[0]) == 3, (
+            "there should be 3 validation results."
+        )
+        assert len(self.m.ctx._ctx) == 2, (
+            "there should only be a message and a response in the ctx."
+        )
+
+    def test_ctx_for_rejection_sampling(self):
+        self.m.ctx.reset()
+        res = self.m.instruct(
+            "Write a sentence.",
+            requirements=[
+                "be funny",
+                "be formal",
+                "use only words starting with the letter w",
+            ],
+            strategy=RejectionSamplingStrategy(loop_budget=3),
+            return_sampling_results=True,
+        )
+        self._run_asserts_for_ctx_testing(res)
+        assert len(self.m.last_prompt()) == 1, (
+            "Last prompt should have only one instruction inside, independent of sampling iterations."
+        )
+
+    def test_ctx_for_multiturn(self):
+        self.m.ctx.reset()
+        res = self.m.instruct(
+            "Write a sentence.",
+            requirements=[
+                "be funny",
+                "be formal",
+                "use only words starting with the letter w",
+            ],
+            strategy=MultiTurnStrategy(loop_budget=3),
+            return_sampling_results=True,
+        )
+
+        self._run_asserts_for_ctx_testing(res)
+
+        assert len(self.m.last_prompt()) == len(res.sample_generations) * 2 - 1, (
+            "For n sampling iterations there should be 2n-1 prompt conversation elements in the last prompt."
+        )
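
---

For reviewers, a minimal usage sketch of the new strategy API. It mirrors the session setup from `test_sampling_ctx.py` above; the instruction text and requirement strings are illustrative only, not part of this diff:

```python
# Sketch only: mirrors test_sampling_ctx.py. The model/session setup and the
# requirement strings are illustrative, not part of this diff.
from mellea import LinearContext, start_session
from mellea.backends import ModelOption
from mellea.stdlib.sampling import RepairTemplateStrategy

m = start_session(
    model_options={ModelOption.MAX_NEW_TOKENS: 100}, ctx=LinearContext()
)

# With the new signatures, session.instruct wires (action, ctx, requirements)
# into strategy.sample(). On failed validations, RepairTemplateStrategy calls
# Instruction.copy_and_repair(), and the repair string listing the failed
# requirements is rendered by the new repair_block in Instruction.jinja2.
res = m.instruct(
    "Write a sentence.",
    requirements=["be funny", "be formal"],
    strategy=RepairTemplateStrategy(loop_budget=3),
    return_sampling_results=True,
)
print(res.success, str(res.value))
```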
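Since `repair` and `select_from_failure` are now abstract static methods on `BaseSamplingStrategy`, third-party strategies only need to implement those two hooks. A sketch of a custom subclass follows; the "keep the attempt that passed the most requirements" policy (`BestOfNStrategy`) is an invented example, not part of this diff:

```python
# Sketch of a custom strategy against the new BaseSamplingStrategy hooks.
# BestOfNStrategy and its selection policy are hypothetical examples.
from mellea.stdlib.base import Component, Context, ModelOutputThunk
from mellea.stdlib.requirement import Requirement, ValidationResult
from mellea.stdlib.sampling import BaseSamplingStrategy


class BestOfNStrategy(BaseSamplingStrategy):
    """Re-samples the same action and keeps the best-scoring failure."""

    @staticmethod
    def repair(
        ctx: Context,
        past_actions: list[Component],
        past_results: list[ModelOutputThunk],
        past_val: list[list[tuple[Requirement, ValidationResult]]],
    ) -> Component:
        # Like RejectionSamplingStrategy: retry the same action unchanged.
        return past_actions[-1]

    @staticmethod
    def select_from_failure(
        sampled_actions: list[Component],
        sampled_results: list[ModelOutputThunk],
        sampled_val: list[list[tuple[Requirement, ValidationResult]]],
    ) -> int:
        # Return the index of the attempt that passed the most requirements.
        passed = [sum(bool(v) for _, v in vals) for vals in sampled_val]
        return passed.index(max(passed))
```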