diff --git a/mellea/stdlib/instruction.py b/mellea/stdlib/instruction.py
index cca98930..e7b5f885 100644
--- a/mellea/stdlib/instruction.py
+++ b/mellea/stdlib/instruction.py
@@ -1,5 +1,7 @@
 """Instructions."""
 
+from __future__ import annotations
+
 from copy import deepcopy
 
 import jinja2
@@ -106,6 +108,7 @@ def __init__(
         self._output_prefix = (
             blockify(output_prefix) if output_prefix is not None else None
         )
+        self._repair_string: str | None = None
 
     def parts(self):
         """Returns all of the constituent parts of an Instruction."""
@@ -132,6 +135,7 @@ def format_for_llm(self) -> TemplateRepresentation:
                 "output_prefix": (
                     self._output_prefix if self._output_prefix is not None else None
                 ),
+                "repair": self._repair_string,
             },
             tools=None,
             template_order=["*", "Instruction"],
@@ -147,3 +151,9 @@ def apply_user_dict_from_jinja(user_dict: dict[str, str], s: str) -> str:
     def requirements(self) -> list[Requirement]:
         """Returns a list of Requirement instances."""
         return self._requirements
+
+    def copy_and_repair(self, repair_string: str) -> Instruction:
+        """Creates a copy of the instruction and adds/overwrites the repair string."""
+        res = deepcopy(self)
+        res._repair_string = repair_string
+        return res
diff --git a/mellea/stdlib/sampling.py b/mellea/stdlib/sampling.py
index c1dd7538..b3d1c0e4 100644
--- a/mellea/stdlib/sampling.py
+++ b/mellea/stdlib/sampling.py
@@ -2,12 +2,22 @@
 
 import abc
 from collections.abc import Callable
+from copy import deepcopy
 from typing import Any
 
 import tqdm
 
 from mellea.helpers.fancy_logger import FancyLogger
-from mellea.stdlib.base import CBlock, GenerateLog, ModelOutputThunk
+from mellea.stdlib.base import (
+    CBlock,
+    Component,
+    Context,
+    ContextTurn,
+    GenerateLog,
+    LinearContext,
+    ModelOutputThunk,
+)
+from mellea.stdlib.chat import Message
 from mellea.stdlib.instruction import Instruction
 from mellea.stdlib.requirement import Requirement, ValidationResult
@@ -23,6 +33,7 @@ def __init__(
         sample_generations: list[ModelOutputThunk] | None = None,
         sample_validations: list[list[tuple[Requirement, ValidationResult]]]
         | None = None,
+        sample_actions: list[Component] | None = None,
     ):
         """Initialize a new instance of sampling results.
@@ -47,31 +58,40 @@ class SamplingStrategy(abc.ABC):
     """
 
     # the function signature here matches that of m.validate
-    validate: Callable[[list[Requirement], Any], list[ValidationResult]] | None = None
+    validate: (
+        Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None
+    ) = None
     generate: (
-        Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None
+        Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
+        | None
     ) = None
 
     @abc.abstractmethod
     def sample(
         self,
-        instruction: Instruction,
+        action: Component,
+        context: Context,
         *,
         generate_logs: list[GenerateLog] | None = None,
+        validation_ctx: Context | None = None,
     ) -> SamplingResult:
         """This method is the abstract method for sampling a given instruction. It must be implemented by any concrete subclasses to provide specific sampling logic.
 
         Args:
-            instruction (Instruction): The instruction object to be sampled.
+            action: The action object to be sampled.
+            context: The context to be passed to the sampling strategy.
             generate_logs: Optional list of GenerateLog objects. If None, no collection happens.
+            validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx.
""" -class RejectionSamplingStrategy(SamplingStrategy): - """Sampling strategy that rejects samples based on given instructions.""" +class BaseSamplingStrategy(SamplingStrategy): + """Base class for multiple strategies that rejects samples based on given instructions.""" + + loop_budget: int def __init__( self, @@ -79,26 +99,29 @@ def __init__( loop_budget: int = 1, repair: Callable[ [ - Instruction, - list[tuple[Requirement, ValidationResult]], - list[Instruction], + Context, + list[Component], + list[ModelOutputThunk], + list[list[tuple[Requirement, ValidationResult]]], ], - Instruction, - ] = lambda i, r, h_i: i, + Component, + ] + | None, select_from_failure: Callable[ [ - Instruction, + list[Component], list[ModelOutputThunk], list[list[tuple[Requirement, ValidationResult]]], ], - ModelOutputThunk, - ] = lambda _, results, __: results[0], - validate: Callable[[list[Requirement], Any], list[ValidationResult]] + int, + ] + | None, + validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None = None, generate: ( - Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None + Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk] + | None ) = None, - requirements: list[Requirement] | None = None, ): """Initialize a new instance of the class with default parameters. @@ -108,32 +131,37 @@ def __init__( select_from_failure: Function to select a model output thunk from failed attempts. validate: Function to validate the results against requirements. If None, validation is provided later through setter. generate: Function to generate new model output thunks. If None, generate is provided later through setter. - requirements: List of requirements to test against. If None, test all requirements attached to the given instruction. Raises: AssertionError: If loop_budget is not greater than 0. """ assert loop_budget > 0, "Loop budget must be at least 1." + assert repair is not None, "Repair must be provided." + assert select_from_failure is not None, "Select from failure must be provided." + self.loop_budget = loop_budget self.repair = repair self.select_from_failure = select_from_failure - self.validate = validate # it's ok to be None here - self.generate = generate # it's ok to be None here - self.requirements = requirements + self.validate = validate # it's ok to be None here. If it is None, m.instruct will set a default validator function. + self.generate = generate # it's ok to be None here. If it is None, m.instruct will set a default generator function. def sample( self, - instruction: Instruction, + action: Component, + context: Context, *, show_progress: bool = True, generate_logs: list[GenerateLog] | None = None, + validation_ctx: Context | None = None, ) -> SamplingResult: """This method performs a sampling operation based on the given instruction. Args: - instruction: The Instruction object containing the instruction to generate a valid model output thunk. - show_progress: if true, a tqdm progress bar is used. Otherwise messages will still be sent to flog. + action : The action object to be sampled. + context: The context to be passed to the sampling strategy. + show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog. generate_logs: If provided, the generations will be logged. + validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx. Returns: SamplingResult: A result object indicating the success or failure of the sampling process. 
@@ -148,68 +176,311 @@ def sample(
         assert self.validate is not None, "Validation must be provided."
         assert self.generate is not None, "Generate must be provided."
 
+        # work on a copy so the original context is not modified
+        ctx = context.copy()
+        validation_ctx = validation_ctx if validation_ctx is not None else context
+        assert validation_ctx is not None, "Validation context must be provided."
+
         flog = FancyLogger.get_logger()
 
-        failed_results: list[ModelOutputThunk] = []
-        failed_scores: list[list[tuple[Requirement, ValidationResult]]] = []
-        failed_instructions: list[Instruction] = []
+        sampled_results: list[ModelOutputThunk] = []
+        sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
+        sampled_actions: list[Component] = []
 
         loop_count = 0
-
         loop_budget_range_iterator = (
-            tqdm.tqdm(range(self.loop_budget))
+            tqdm.tqdm(range(self.loop_budget))  # type: ignore
             if show_progress
-            else range(self.loop_budget)
+            else range(self.loop_budget)  # type: ignore
         )
+
+        new_action = deepcopy(action)
         for _ in loop_budget_range_iterator:  # type: ignore
             loop_count += 1
             if not show_progress:
                 flog.info(f"Running loop {loop_count} of {self.loop_budget}")
 
-            # run a pass
-            result = self.generate(instruction, generate_logs)
+            # run a generation pass
+            result = self.generate(new_action, ctx, generate_logs)
 
-            if self.requirements is not None:
-                reqs = self.requirements
-            else:
-                reqs = instruction.requirements
-            val_scores = self.validate(reqs, result)
-            constraint_scores = list(zip(reqs, val_scores))
+            # validation pass
+            val_scores = self.validate(action.requirements, validation_ctx, result)
+
+            # match up reqs with scores
+            constraint_scores = list(zip(action.requirements, val_scores))
 
-            failed_results.append(result)
-            failed_scores.append(constraint_scores)
-            failed_instructions.append(instruction)
+            # collect all data
+            sampled_results.append(result)
+            sampled_scores.append(constraint_scores)
+            sampled_actions.append(new_action)
 
+            # if all vals are true -- break and return success
             if all(bool(s[1]) for s in constraint_scores):
                 flog.info("SUCCESS")
                 return SamplingResult(
                     result,
                     success=True,
-                    sample_generations=failed_results,
-                    sample_validations=failed_scores,
+                    sample_generations=sampled_results,
+                    sample_validations=sampled_scores,
+                    sample_actions=sampled_actions,
                 )
 
             else:
+                # log partial success and continue
                 count_valid = len([s for s in constraint_scores if bool(s[1])])
                 flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")
 
-            instruction = self.repair(
-                instruction, constraint_scores, failed_instructions
+            # If we did not pass all constraints, update the action and try again.
+            new_action = self.repair(
+                ctx, sampled_actions, sampled_results, sampled_scores
             )
 
         flog.info(
-            f"Invoking select_from_failure after {len(failed_results)} failed attempts."
+            f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
         )
-        best_failed_result = self.select_from_failure(
-            instruction, failed_results, failed_scores
+
+        # if no valid result could be determined, find a last resort.
+        best_failed_index = self.select_from_failure(
+            sampled_actions, sampled_results, sampled_scores
         )
-        assert best_failed_result in failed_results, (
-            "The select_from_failure method did not return a valid result. It has to selected from failed_results."
+        assert best_failed_index < len(sampled_results), (
+            "The select_from_failure method did not return a valid result. It has to select an index into sampled_results."
         )
         return SamplingResult(
-            best_failed_result,
+            sampled_results[best_failed_index],
             success=False,
-            sample_generations=failed_results,
-            sample_validations=failed_scores,
+            sample_generations=sampled_results,
+            sample_validations=sampled_scores,
+            sample_actions=sampled_actions,
         )
+
+
+class RejectionSamplingStrategy(BaseSamplingStrategy):
+    """Simple rejection sampling strategy with optional repair."""
+
+    def __init__(
+        self,
+        *,
+        loop_budget: int = 1,
+        repair: Callable[
+            [
+                list[Component],
+                list[ModelOutputThunk],
+                list[list[tuple[Requirement, ValidationResult]]],
+            ],
+            Component,
+        ] = lambda past_actions, past_results, past_val: past_actions[-1],
+        select_from_failure: Callable[
+            [
+                list[Component],
+                list[ModelOutputThunk],
+                list[list[tuple[Requirement, ValidationResult]]],
+            ],
+            int,
+        ] = lambda past_actions, past_results, past_val: 0,
+        validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
+        | None = None,
+        generate: (
+            Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
+            | None
+        ) = None,
+    ):
+        def repair_wrapper(ctx, past_actions, past_results, past_val):
+            # ctx is not used
+            return repair(past_actions, past_results, past_val)
+
+        super().__init__(
+            loop_budget=loop_budget,
+            repair=repair_wrapper,
+            select_from_failure=select_from_failure,
+            validate=validate,
+            generate=generate,
+        )
+
+
+class AgenticSamplingStrategy(BaseSamplingStrategy):
+    """Rejection sampling strategy with agentic (multi-turn) repair."""
+
+    def __init__(
+        self,
+        *,
+        loop_budget: int = 1,
+        repair: Callable[
+            [
+                Context,
+                list[Component],
+                list[ModelOutputThunk],
+                list[list[tuple[Requirement, ValidationResult]]],
+            ],
+            Component,
+        ]
+        | None = None,
+        select_from_failure: Callable[
+            [
+                list[Component],
+                list[ModelOutputThunk],
+                list[list[tuple[Requirement, ValidationResult]]],
+            ],
+            int,
+        ] = lambda past_actions, past_results, past_val: len(past_actions) - 1,
+        validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
+        | None = None,
+        generate: (
+            Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
+            | None
+        ) = None,
+    ):
+        if repair is None:
+            repair = AgenticSamplingStrategy.agentic_repair_default
+
+        super().__init__(
+            loop_budget=loop_budget,
+            repair=repair,
+            select_from_failure=select_from_failure,
+            validate=validate,
+            generate=generate,
+        )
+
+    @staticmethod
+    def agentic_repair_default(
+        context: Context,
+        past_actions: list[Component],
+        past_results: list[ModelOutputThunk],
+        past_val: list[list[tuple[Requirement, ValidationResult]]],
+    ) -> Component:
+        assert isinstance(context, LinearContext), (
+            "Need a LinearContext to run agentic sampling."
+        )
+
+        # add failed execution to chat history
+        context.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
+
+        last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]]
+        last_failed_reqs_str = "* " + "\n* ".join(
+            [str(r.description) for r in last_failed_reqs]
+        )
+        # TODO: what to do with checks ??
+
+        next_action = Message(
+            role="user",
+            content=f"The following requirements have not been met:\n{last_failed_reqs_str}\nPlease try again to fulfill the requirements.",
+        )
+
+        return next_action
+
+
+class RepairSamplingStrategy(RejectionSamplingStrategy):
+    """Rejection sampling that autoregressively generates a new sample from the previous failure."""
+
+    def sample(
+        self,
+        action: Component,
+        context: Context,
+        *,
+        show_progress: bool = True,
+        generate_logs: list[GenerateLog] | None = None,
+        validation_ctx: Context | None = None,
+    ) -> SamplingResult:
+        """This method performs a sampling operation based on the given action.
+
+        Args:
+            action: The action object to be sampled.
+            context: The context to be passed to the sampling strategy.
+            show_progress: If true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
+            generate_logs: If provided, the generations will be logged.
+            validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx.
+
+        Returns:
+            SamplingResult: A result object indicating the success or failure of the sampling process.
+
+        Raises:
+            AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling.
+        """
+        assert self.repair is not None, "Repair must be provided."
+        assert self.select_from_failure is not None, (
+            "Select from failure must be provided."
+        )
+        assert self.validate is not None, "Validation must be provided."
+        assert self.generate is not None, "Generate must be provided."
+
+        assert isinstance(context, LinearContext), (
+            f"expecting a linear context for repairing, got: {type(context).__name__}"
+        )
+
+        # work on a copy so the original context is not modified
+        ctx = context.copy()
+        validation_ctx = validation_ctx if validation_ctx is not None else context
+        assert validation_ctx is not None, "Validation context must be provided."
+
+        flog = FancyLogger.get_logger()
+
+        sampled_results: list[ModelOutputThunk] = []
+        sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
+        sampled_actions: list[Component] = []
+
+        loop_count = 0
+        loop_budget_range_iterator = (
+            tqdm.tqdm(range(self.loop_budget))  # type: ignore
+            if show_progress
+            else range(self.loop_budget)  # type: ignore
+        )
+
+        new_action = deepcopy(action)
+        for _ in loop_budget_range_iterator:  # type: ignore
+            loop_count += 1
+            if not show_progress:
+                flog.info(f"Running loop {loop_count} of {self.loop_budget}")
+
+            # run a generation pass
+            result = self.generate(new_action, ctx, generate_logs)
+
+            ctx.insert_turn(ContextTurn(new_action, result))
+
+            # validation pass
+            val_scores = self.validate(action.requirements, validation_ctx, result)
+
+            # match up reqs with scores
+            constraint_scores = list(zip(action.requirements, val_scores))
+
+            # collect all data
+            sampled_results.append(result)
+            sampled_scores.append(constraint_scores)
+            sampled_actions.append(new_action)
+
+            # if all vals are true -- break and return success
+            if all(bool(s[1]) for s in constraint_scores):
+                flog.info("SUCCESS")
+                return SamplingResult(
+                    result,
+                    success=True,
+                    sample_generations=sampled_results,
+                    sample_validations=sampled_scores,
+                    sample_actions=sampled_actions,
+                )
+
+            else:
+                # log partial success and continue
+                count_valid = len([s for s in constraint_scores if bool(s[1])])
+                flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")
+
+                new_action = Instruction(
+                    "\n* ".join(
+                        ["Your previous result does not follow these requirements:"]
+                        + [
+                            str(req.description)
+                            for req, res in constraint_scores
+                            if not res
+                        ]
+                    )
+                )
+
+        flog.info(
+            f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
+        )
+
+        # if no valid result could be determined, find a last resort.
+        best_failed_index = self.select_from_failure(
+            sampled_actions, sampled_results, sampled_scores
+        )
+        assert best_failed_index < len(sampled_results), (
+            "The select_from_failure method did not return a valid result. It has to select an index into sampled_results."
+        )
+        return SamplingResult(
+            sampled_results[best_failed_index],
+            success=False,
+            sample_generations=sampled_results,
+            sample_validations=sampled_scores,
+            sample_actions=sampled_actions,
+        )
diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py
index cc0d3378..280ade4e 100644
--- a/mellea/stdlib/session.py
+++ b/mellea/stdlib/session.py
@@ -199,15 +199,17 @@ def instruct(
                 generate_logs[0].is_final_result = True
         else:
             if strategy.validate is None:
-                strategy.validate = lambda reqs, output: self.validate(  # type: ignore
+                strategy.validate = lambda reqs, val_ctx, output: self.validate(  # type: ignore
                     reqs,
                     output=output,  # type: ignore
                 )  # type: ignore
             if strategy.generate is None:
                 strategy.generate = (
-                    lambda instruction, g_logs: self.backend.generate_from_context(
+                    lambda instruction,
+                    gen_ctx,
+                    g_logs: self.backend.generate_from_context(
                         instruction,
-                        ctx=self.ctx,
+                        ctx=gen_ctx,
                         format=format,
                         model_options=model_options,
                         generate_logs=g_logs,
@@ -216,7 +218,7 @@ def instruct(
                 )
 
             # sample
-            res = strategy.sample(i, generate_logs=generate_logs)
+            res = strategy.sample(i, self.ctx, generate_logs=generate_logs)
 
             # make sure that one Log is marked as the one related to res.result
             if res.success:
diff --git a/mellea/templates/prompts/default/Instruction.jinja2 b/mellea/templates/prompts/default/Instruction.jinja2
index c576b806..52d0766e 100644
--- a/mellea/templates/prompts/default/Instruction.jinja2
+++ b/mellea/templates/prompts/default/Instruction.jinja2
@@ -37,8 +37,14 @@ Here is some grounding context:
 {%- endif -%}
 {% endblock grounding_context %}
 
+{%- block repair_block -%}
+{% if repair %}
+{{ repair -}}
+{%- endif -%}
+{% endblock repair_block %}
+
 {%- block output_prefix -%}
 {% if output_prefix %}
 {{ output_prefix -}}
 {%- endif -%}
-{% endblock output_prefix %}
\ No newline at end of file
+{% endblock output_prefix %}
diff --git a/mellea/templates/prompts/granite/Instruction.jinja2 b/mellea/templates/prompts/granite/Instruction.jinja2
index ea9851e9..a2c79d6a 100644
--- a/mellea/templates/prompts/granite/Instruction.jinja2
+++ b/mellea/templates/prompts/granite/Instruction.jinja2
@@ -37,8 +37,14 @@ Here are some examples of what the response might look like:
 {%- endif -%}
 {% endblock icl_examples %}
 
+{%- block repair_block -%}
+{% if repair %}
+{{ repair -}}
+{%- endif -%}
+{% endblock repair_block %}
+
 {%- block output_prefix -%}
 {% if output_prefix %}
 {{ output_prefix -}}
 {%- endif -%}
-{% endblock output_prefix %}
\ No newline at end of file
+{% endblock output_prefix %}
diff --git a/test/sampling/test_repair.py b/test/sampling/test_repair.py
new file mode 100644
index 00000000..1a05fbdf
--- /dev/null
+++ b/test/sampling/test_repair.py
@@ -0,0 +1,135 @@
+import re
+import textwrap
+from collections.abc import Callable, Iterable
+
+import pytest
+
+
+def wrapped(text: str | Iterable[str], width: int = 100, fn: Callable = print):
+    """
+    Print a text string in a reasonable screen width (100).
+    Swapping fn with a custom function (e.g., yp, rp, ...) is handy.
+    """
+    if not isinstance(text, str):
+        for elem in text:
+            for subline in textwrap.wrap(str(elem), width=width, initial_indent="* ", subsequent_indent="  "):
+                fn(subline)
+    else:
+        for line in text.split("\n"):
+            for subline in textwrap.wrap(line, width=width):
+                fn(subline)
+
+
+import mellea
+from mellea.stdlib.sampling import (
+    RejectionSamplingStrategy,
+    RepairSamplingStrategy,
+)
+from mellea.stdlib.requirement import (
+    Requirement,
+)
+from mellea.stdlib.base import (
+    LinearContext,
+)
+
+from mellea.backends import ModelOption
+
+import matplotlib.colors as mcolors
+import colorsys
+
+
+def css_to_lightness(name: str):
+    """Return the lightness of a CSS color name, i.e. the HLS l channel in [0, 1]."""
+    # matplotlib gives us RGB in [0, 1]
+    rgb = mcolors.to_rgb(name)  # tuple of floats
+    # convert to HLS (colorsys uses h, l, s)
+    h, l, s = colorsys.rgb_to_hls(*rgb)
+    return l
+
+
+def check_sorted(ctx):
+    colors = [s.strip() for s in ctx.last_output().value.split(",")]
+    lightnesses = map(css_to_lightness, colors)
+    colors_sorted = list(map(lambda x: x[0], sorted(zip(colors, lightnesses), key=lambda x: x[1])))
+    print("ground truth:", colors_sorted)
+    return colors == colors_sorted
+
+
+def check_sorted12(ctx):
+    colors = [s.strip() for s in ctx.last_output().value.split(",")]
+    lightnesses = list(map(css_to_lightness, colors))
+    return lightnesses[0] < lightnesses[1]
+
+
+def check_sorted23(ctx):
+    colors = [s.strip() for s in ctx.last_output().value.split(",")]
+    lightnesses = list(map(css_to_lightness, colors))
+    return lightnesses[1] < lightnesses[2]
+
+
+def check_sorted34(ctx):
+    colors = [s.strip() for s in ctx.last_output().value.split(",")]
+    lightnesses = list(map(css_to_lightness, colors))
+    return lightnesses[2] < lightnesses[3]
+
+
+def check_last_line(ctx) -> bool:
+    text: str = ctx.last_output().value
+    # split into lines, drop empty lines at the end
+    lines = text.rstrip().splitlines()
+    if not lines:
+        return False
+    last = lines[-1].strip()
+    if not last:
+        return False
+    # regex: one or more "words" separated by commas (allowing spaces)
+    pattern = r'^\s*[^,]+(\s*,\s*[^,]+)*\s*$'
+    return bool(re.match(pattern, last))
+
+
+def check_markdown(ctx) -> bool:
+    text: str = ctx.last_output().value
+    # Patterns for common markdown styles
+    markdown_patterns = [
+        r'\*{1,2}[^*]+\*{1,2}',     # *italic* or **bold**
+        r'_{1,2}[^_]+_{1,2}',       # _italic_ or __bold__
+        r'`[^`]+`',                 # `inline code`
+        r'~~[^~]+~~',               # ~~strikethrough~~
+        r'^#+\s',                   # # heading
+        r'^>\s',                    # > blockquote
+        r'!\[[^\]]*\]\([^)]+\)',    # ![alt](url) image
+        r'\[[^\]]+\]\([^)]+\)',     # [text](url) link
+    ]
+
+    # the requirement passes only when no markdown markup is found
+    for pat in markdown_patterns:
+        if re.search(pat, text, re.MULTILINE):
+            return False
+    return True
+
+
+class Test:
+    m = mellea.start_session(
+        backend_name="ollama",
+        model_id="qwen3:1.7b",
+        model_options={ModelOption.THINKING: False},
+        ctx=LinearContext(),
+    )
+
+    @pytest.mark.xfail(reason="this task is difficult")
+    def test_color_sort(self) -> None:
+        ans = self.m.instruct(
+            "Sort these colors by the increasing order of lightness in HSL scale. colors: lavender, orange, purple, blue.",
+            requirements=[
+                "Be succinct and simply return the answer in the last line without explaining the steps.",
+                Requirement("Do not use markdown or any other plain-text markup format.", validation_fn=check_markdown),
+                Requirement("Format the last line as a comma separated list without spaces", validation_fn=check_last_line),
+                Requirement("the first output has less lightness than the second output.", validation_fn=check_sorted12),
+                Requirement("the second output has less lightness than the third output.", validation_fn=check_sorted23),
+                Requirement("the third output has less lightness than the fourth output.", validation_fn=check_sorted34),
+                Requirement("the output is sorted by the lightness in the hsl model.", validation_fn=check_sorted),
+            ],
+            strategy=RepairSamplingStrategy(loop_budget=3),
+            return_sampling_results=True,
+        )
+
+        assert ans.success
+
+
+if __name__ == "__main__":
+    # import fattrace
+    # fattrace.install()
+    pytest.main([__file__])
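
A minimal usage sketch (not part of the patch) of the new repair hook. It shows how a custom repair callable, matching the RejectionSamplingStrategy signature above, can use Instruction.copy_and_repair so that the repair string is rendered through the new repair block in Instruction.jinja2. The session arguments mirror test_repair.py; the helper name repair_with_feedback, the prompt, and the requirement text are illustrative assumptions.

import mellea
from mellea.backends import ModelOption
from mellea.stdlib.base import LinearContext
from mellea.stdlib.instruction import Instruction
from mellea.stdlib.sampling import RejectionSamplingStrategy


def repair_with_feedback(past_actions, past_results, past_val):
    """Attach the last round's unmet requirements to a copy of the last action."""
    last_action = past_actions[-1]
    failed = [str(req.description) for req, res in past_val[-1] if not res]
    if isinstance(last_action, Instruction) and failed:
        # copy_and_repair fills the repair slot rendered by the Instruction templates
        return last_action.copy_and_repair(
            "The previous attempt did not meet these requirements:\n* " + "\n* ".join(failed)
        )
    return last_action


m = mellea.start_session(
    backend_name="ollama",
    model_id="qwen3:1.7b",
    model_options={ModelOption.THINKING: False},
    ctx=LinearContext(),
)
res = m.instruct(
    "Name three CSS colors.",
    requirements=["Answer with a comma separated list on a single line."],
    strategy=RejectionSamplingStrategy(loop_budget=3, repair=repair_with_feedback),
    return_sampling_results=True,
)
print(res.success, res.result.value)

Because RejectionSamplingStrategy wraps the callable and drops the context argument, the hook only sees the sampling history; a repair that needs the chat history should subclass BaseSamplingStrategy or use AgenticSamplingStrategy instead.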
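
A similar sketch, under the same assumptions, for the agentic variant. AgenticSamplingStrategy requires a LinearContext: its default repair (agentic_repair_default) appends the failed turn to the chat history and sends a follow-up user Message listing the unmet requirements. The prompt and requirement strings are again illustrative.

import mellea
from mellea.backends import ModelOption
from mellea.stdlib.base import LinearContext
from mellea.stdlib.sampling import AgenticSamplingStrategy

m = mellea.start_session(
    backend_name="ollama",
    model_id="qwen3:1.7b",
    model_options={ModelOption.THINKING: False},
    ctx=LinearContext(),  # agentic repair asserts on LinearContext
)
res = m.instruct(
    "Sort these colors from darkest to lightest: navy, gold, teal.",
    requirements=["Return only the sorted list, comma separated."],
    strategy=AgenticSamplingStrategy(loop_budget=3),
    return_sampling_results=True,
)
print(res.success)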