Skip to content

Commit c9de1f9

Browse files
adding RejectionSampling and AgenticSampling as subclasses of BaseSampling
1 parent f02e88b commit c9de1f9

File tree

1 file changed

+139
-8
lines changed

1 file changed

+139
-8
lines changed

mellea/stdlib/sampling.py

Lines changed: 139 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,17 @@
77

88
import tqdm
99

10+
from mellea import LinearContext
1011
from mellea.helpers.fancy_logger import FancyLogger
11-
from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk
12+
from mellea.stdlib.base import (
13+
CBlock,
14+
Component,
15+
Context,
16+
ContextTurn,
17+
GenerateLog,
18+
ModelOutputThunk,
19+
)
20+
from mellea.stdlib.chat import Message
1221
from mellea.stdlib.instruction import Instruction
1322
from mellea.stdlib.requirement import Requirement, ValidationResult
1423

@@ -79,8 +88,8 @@ def sample(
7988
"""
8089

8190

82-
class RejectionSamplingStrategy(SamplingStrategy):
83-
"""Sampling strategy that rejects samples based on given instructions."""
91+
class BaseSamplingStrategy(SamplingStrategy):
92+
"""Base class for multiple strategies that rejects samples based on given instructions."""
8493

8594
loop_budget: int
8695

@@ -90,21 +99,23 @@ def __init__(
9099
loop_budget: int = 1,
91100
repair: Callable[
92101
[
93-
Component,
94102
Context,
95-
list[tuple[Requirement, ValidationResult]],
96103
list[Component],
104+
list[ModelOutputThunk],
105+
list[list[tuple[Requirement, ValidationResult]]],
97106
],
98107
Component,
99-
] = lambda i, c, r, h_i: i,
108+
]
109+
| None,
100110
select_from_failure: Callable[
101111
[
102112
list[Component],
103113
list[ModelOutputThunk],
104114
list[list[tuple[Requirement, ValidationResult]]],
105115
],
106116
int,
107-
] = lambda _, results, __: 0,
117+
]
118+
| None,
108119
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
109120
| None = None,
110121
generate: (
@@ -127,6 +138,9 @@ def __init__(
127138
AssertionError: If loop_budget is not greater than 0.
128139
"""
129140
assert loop_budget > 0, "Loop budget must be at least 1."
141+
assert repair is not None, "Repair must be provided."
142+
assert select_from_failure is not None, "Select from failure must be provided."
143+
130144
self.loop_budget = loop_budget
131145
self.repair = repair
132146
self.select_from_failure = select_from_failure
@@ -229,7 +243,7 @@ def sample(
229243

230244
# If we did not pass all constraints, update the instruction and try again.
231245
new_action = self.repair(
232-
new_action, ctx, constraint_scores, sampled_actions
246+
ctx, sampled_actions, sampled_results, sampled_scores
233247
)
234248

235249
flog.info(
@@ -250,3 +264,120 @@ def sample(
250264
sample_validations=sampled_scores,
251265
sample_actions=sampled_actions,
252266
)
267+
268+
269+
class RejectionSamplingStrategy(BaseSamplingStrategy):
270+
"""Simple rejection sampling strategy with optional repair."""
271+
272+
def __init__(
273+
self,
274+
*,
275+
loop_budget: int = 1,
276+
repair: Callable[
277+
[
278+
list[Component],
279+
list[ModelOutputThunk],
280+
list[list[tuple[Requirement, ValidationResult]]],
281+
],
282+
Component,
283+
] = lambda past_actions, past_results, past_val: past_actions[-1],
284+
select_from_failure: Callable[
285+
[
286+
list[Component],
287+
list[ModelOutputThunk],
288+
list[list[tuple[Requirement, ValidationResult]]],
289+
],
290+
int,
291+
] = lambda past_actions, past_results, past_val: 0,
292+
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
293+
| None = None,
294+
generate: (
295+
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
296+
| None
297+
) = None,
298+
requirements: list[Requirement] | None = None,
299+
):
300+
def repair_wrapper(_, past_actions, past_results, past_val):
301+
return repair(past_actions, past_results, past_val)
302+
303+
super().__init__(
304+
loop_budget=loop_budget,
305+
repair=repair_wrapper,
306+
select_from_failure=select_from_failure,
307+
validate=validate,
308+
generate=generate,
309+
requirements=requirements,
310+
)
311+
312+
313+
class AgenticSamplingStrategy(BaseSamplingStrategy):
314+
"""Rejection sampling strategy with agentic (multi-turn) repair."""
315+
316+
def __init__(
317+
self,
318+
*,
319+
loop_budget: int = 1,
320+
repair: Callable[
321+
[
322+
Context,
323+
list[Component],
324+
list[ModelOutputThunk],
325+
list[list[tuple[Requirement, ValidationResult]]],
326+
],
327+
Component,
328+
]
329+
| None = None,
330+
select_from_failure: Callable[
331+
[
332+
list[Component],
333+
list[ModelOutputThunk],
334+
list[list[tuple[Requirement, ValidationResult]]],
335+
],
336+
int,
337+
] = lambda past_actions, past_results, past_val: len(past_actions) - 1,
338+
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
339+
| None = None,
340+
generate: (
341+
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
342+
| None
343+
) = None,
344+
requirements: list[Requirement] | None = None,
345+
):
346+
if repair is None:
347+
repair = AgenticSamplingStrategy.agentic_repair_default
348+
349+
super().__init__(
350+
loop_budget=loop_budget,
351+
repair=repair,
352+
select_from_failure=select_from_failure,
353+
validate=validate,
354+
generate=generate,
355+
requirements=requirements,
356+
)
357+
358+
@staticmethod
359+
def agentic_repair_default(
360+
context: Context,
361+
past_actions: list[Component],
362+
past_results: list[ModelOutputThunk],
363+
past_val: list[list[tuple[Requirement, ValidationResult]]],
364+
) -> Component:
365+
assert isinstance(context, LinearContext), (
366+
" Need linear context to run agentic sampling."
367+
)
368+
369+
# add failed execution to chat history
370+
context.insert_turn(ContextTurn(past_actions[-1], past_results[-1]))
371+
372+
last_failed_reqs: list[Requirement] = [s[0] for s in past_val[-1] if not s[1]]
373+
last_failed_reqs_str = "* " + "\n* ".join(
374+
[str(r.description) for r in last_failed_reqs]
375+
)
376+
# TODO: what to do with checks ??
377+
378+
next_action = Message(
379+
role="user",
380+
content=f"The following requirements have not been met: \n{last_failed_reqs_str}\n Please try again to fulfill the requirements.",
381+
)
382+
383+
return next_action

0 commit comments

Comments
 (0)