Skip to content

Commit 570dd7f

Browse files
refactoring repair and select-from-failure as abtsract methods.
1 parent decab14 commit 570dd7f

File tree

2 files changed

+106
-115
lines changed

2 files changed

+106
-115
lines changed

mellea/stdlib/sampling.py

Lines changed: 103 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -99,25 +99,6 @@ def __init__(
9999
self,
100100
*,
101101
loop_budget: int = 1,
102-
repair: Callable[
103-
[
104-
Context,
105-
list[Component],
106-
list[ModelOutputThunk],
107-
list[list[tuple[Requirement, ValidationResult]]],
108-
],
109-
Component,
110-
]
111-
| None,
112-
select_from_failure: Callable[
113-
[
114-
list[Component],
115-
list[ModelOutputThunk],
116-
list[list[tuple[Requirement, ValidationResult]]],
117-
],
118-
int,
119-
]
120-
| None,
121102
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
122103
| None = None,
123104
generate: (
@@ -130,8 +111,6 @@ def __init__(
130111
131112
Args:
132113
loop_budget: Number of times to iterate through the process. Must be greater than 0.
133-
repair: Function to apply "repairs" to an instruction based on its requirements and validation results.
134-
select_from_failure: Function to select a model output thunk from failed attempts.
135114
validate: Function to validate the results against requirements. If None, validation is provided later through setter.
136115
generate: Function to generate new model output thunks. If None, generate is provided later through setter.
137116
requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
@@ -140,16 +119,53 @@ def __init__(
140119
AssertionError: If loop_budget is not greater than 0.
141120
"""
142121
assert loop_budget > 0, "Loop budget must be at least 1."
143-
assert repair is not None, "Repair must be provided."
144-
assert select_from_failure is not None, "Select from failure must be provided."
145122

146123
self.loop_budget = loop_budget
147-
self.repair = repair
148-
self.select_from_failure = select_from_failure
149124
self.validate = validate # it's ok to be None here
150125
self.generate = generate # it's ok to be None here
151126
self.requirements = requirements
152127

128+
@staticmethod
129+
@abc.abstractmethod
130+
def repair(
131+
ctx: Context,
132+
past_actions: list[Component],
133+
past_results: list[ModelOutputThunk],
134+
past_val: list[list[tuple[Requirement, ValidationResult]]],
135+
) -> Component:
136+
"""
137+
Repair function that is being invoked if not all requirements are fulfilled. It should return a next action component.
138+
139+
Args:
140+
ctx: The context to be passed to the sampling strategy.
141+
past_actions: List of actions that have been executed (without success).
142+
past_results: List of (unsuccessful) generation results for these actions.
143+
past_val: List of validation results for the results.
144+
145+
Returns:
146+
The next action component.
147+
"""
148+
...
149+
150+
@staticmethod
151+
@abc.abstractmethod
152+
def select_from_failure(
153+
sampled_actions: list[Component],
154+
sampled_results: list[ModelOutputThunk],
155+
sampled_val: list[list[tuple[Requirement, ValidationResult]]],
156+
):
157+
"""This function returns the index of the result that should be selected as `.value` iff the loop budget is exhausted and no success.
158+
159+
Args:
160+
sampled_actions: List of actions that have been executed (without success).
161+
sampled_results: List of (unsuccessful) generation results for these actions.
162+
sampled_val: List of validation results for the results.
163+
164+
Returns:
165+
The index of the result that should be selected as `.value`.
166+
"""
167+
...
168+
153169
def sample(
154170
self,
155171
action: Component,
@@ -176,10 +192,6 @@ def sample(
176192
Raises:
177193
AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling.
178194
"""
179-
assert self.repair is not None, "Repair must be provided."
180-
assert self.select_from_failure is not None, (
181-
"Select from failure must be provided."
182-
)
183195
assert self.validate is not None, "Validation must be provided."
184196
assert self.generate is not None, "Generate must be provided."
185197

@@ -271,96 +283,75 @@ def sample(
271283

272284

273285
class RejectionSamplingStrategy(BaseSamplingStrategy):
274-
"""Simple rejection sampling strategy with optional repair."""
286+
"""Simple rejection sampling strategy that just repeats the same call on failure."""
275287

276-
def __init__(
277-
self,
278-
*,
279-
loop_budget: int = 1,
280-
repair: Callable[
281-
[
282-
list[Component],
283-
list[ModelOutputThunk],
284-
list[list[tuple[Requirement, ValidationResult]]],
285-
],
286-
Component,
287-
] = lambda past_actions, past_results, past_val: past_actions[-1],
288-
select_from_failure: Callable[
289-
[
290-
list[Component],
291-
list[ModelOutputThunk],
292-
list[list[tuple[Requirement, ValidationResult]]],
293-
],
294-
int,
295-
] = lambda past_actions, past_results, past_val: 0,
296-
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
297-
| None = None,
298-
generate: (
299-
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
300-
| None
301-
) = None,
302-
requirements: list[Requirement] | None = None,
303-
):
304-
def repair_wrapper(_, past_actions, past_results, past_val):
305-
return repair(past_actions, past_results, past_val)
306-
307-
super().__init__(
308-
loop_budget=loop_budget,
309-
repair=repair_wrapper,
310-
select_from_failure=select_from_failure,
311-
validate=validate,
312-
generate=generate,
313-
requirements=requirements,
314-
)
288+
@staticmethod
289+
def select_from_failure(
290+
sampled_actions: list[Component],
291+
sampled_results: list[ModelOutputThunk],
292+
sampled_val: list[list[tuple[Requirement, ValidationResult]]],
293+
) -> int:
294+
# simply returns the first attempt if all loops fail
295+
return 0
315296

297+
@staticmethod
298+
def repair(
299+
ctx: Context,
300+
past_actions: list[Component],
301+
past_results: list[ModelOutputThunk],
302+
past_val: list[list[tuple[Requirement, ValidationResult]]],
303+
) -> Component:
304+
# repeat the last action again.
305+
return past_actions[-1]
316306

317-
class AgenticSamplingStrategy(BaseSamplingStrategy):
318-
"""Rejection sampling strategy with agentic (multi-turn) repair."""
319307

320-
def __init__(
321-
self,
322-
*,
323-
loop_budget: int = 1,
324-
repair: Callable[
325-
[
326-
Context,
327-
list[Component],
328-
list[ModelOutputThunk],
329-
list[list[tuple[Requirement, ValidationResult]]],
330-
],
331-
Component,
332-
]
333-
| None = None,
334-
select_from_failure: Callable[
335-
[
336-
list[Component],
337-
list[ModelOutputThunk],
338-
list[list[tuple[Requirement, ValidationResult]]],
339-
],
340-
int,
341-
] = lambda past_actions, past_results, past_val: len(past_actions) - 1,
342-
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
343-
| None = None,
344-
generate: (
345-
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
346-
| None
347-
) = None,
348-
requirements: list[Requirement] | None = None,
308+
class RepairTemplateStrategy(BaseSamplingStrategy):
309+
"""A sampling strategy that adds a repair string to the instruction object."""
310+
311+
@staticmethod
312+
def select_from_failure(
313+
sampled_actions: list[Component],
314+
sampled_results: list[ModelOutputThunk],
315+
sampled_val: list[list[tuple[Requirement, ValidationResult]]],
316+
) -> int:
317+
# simply returns the first attempt if all loops fail
318+
return 0
319+
320+
@staticmethod
321+
def repair(
322+
ctx: Context,
323+
past_actions: list[Component],
324+
past_results: list[ModelOutputThunk],
325+
past_val: list[list[tuple[Requirement, ValidationResult]]],
326+
) -> Component:
327+
pa = past_actions[-1]
328+
if isinstance(pa, Instruction):
329+
last_failed_reqs: list[Requirement] = [
330+
s[0] for s in past_val[-1] if not s[1]
331+
]
332+
last_failed_reqs_str = "* " + "\n* ".join(
333+
[str(r.description) for r in last_failed_reqs]
334+
)
335+
return pa.copy_and_repair(
336+
repair_string=f"The following requirements failed before:\n{last_failed_reqs_str}"
337+
)
338+
return past_actions[-1]
339+
340+
341+
class MultiTurnStrategy(BaseSamplingStrategy):
342+
"""Rejection sampling strategy with (agentic) multi-turn repair."""
343+
344+
@staticmethod
345+
def select_from_failure(
346+
sampled_actions: list[Component],
347+
sampled_results: list[ModelOutputThunk],
348+
sampled_val: list[list[tuple[Requirement, ValidationResult]]],
349349
):
350-
if repair is None:
351-
repair = AgenticSamplingStrategy.agentic_repair_default
352-
353-
super().__init__(
354-
loop_budget=loop_budget,
355-
repair=repair,
356-
select_from_failure=select_from_failure,
357-
validate=validate,
358-
generate=generate,
359-
requirements=requirements,
360-
)
350+
# return the last assistant message even if all attempts of repair failed.
351+
return -1
361352

362353
@staticmethod
363-
def agentic_repair_default(
354+
def repair(
364355
context: Context,
365356
past_actions: list[Component],
366357
past_results: list[ModelOutputThunk],

test/stdlib_basics/test_sampling_ctx.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from mellea import LinearContext, start_session
22
from mellea.backends import ModelOption
33
from mellea.stdlib.sampling import (
4-
AgenticSamplingStrategy,
4+
MultiTurnStrategy,
55
RejectionSamplingStrategy,
66
SamplingResult,
77
)
@@ -45,7 +45,7 @@ def test_ctx_for_rejection_sampling(self):
4545
self._run_asserts_for_ctx_testing(res)
4646
assert len(self.m.last_prompt()) == 1, "Last prompt should only have only one instruction inside - independent of sampling iterations."
4747

48-
def test_ctx_for_agentic(self):
48+
def test_ctx_for_multiturn(self):
4949
self.m.ctx.reset()
5050
res = self.m.instruct(
5151
"Write a sentence.",
@@ -54,7 +54,7 @@ def test_ctx_for_agentic(self):
5454
"be formal",
5555
"use only words starting with the letter w",
5656
],
57-
strategy=AgenticSamplingStrategy(loop_budget=3),
57+
strategy=MultiTurnStrategy(loop_budget=3),
5858
return_sampling_results=True,
5959
)
6060

0 commit comments

Comments
 (0)