77
88import tqdm
99
10+ from mellea import LinearContext
1011from mellea .helpers .fancy_logger import FancyLogger
11- from mellea .stdlib .base import CBlock , Component , Context , GenerateLog , ModelOutputThunk
12+ from mellea .stdlib .base import (
13+ CBlock ,
14+ Component ,
15+ Context ,
16+ ContextTurn ,
17+ GenerateLog ,
18+ ModelOutputThunk ,
19+ )
20+ from mellea .stdlib .chat import Message
1221from mellea .stdlib .instruction import Instruction
1322from mellea .stdlib .requirement import Requirement , ValidationResult
1423
@@ -79,8 +88,8 @@ def sample(
7988 """
8089
8190
82- class RejectionSamplingStrategy (SamplingStrategy ):
83- """Sampling strategy that rejects samples based on given instructions."""
91+ class BaseSamplingStrategy (SamplingStrategy ):
92+ """Base class for multiple strategies that rejects samples based on given instructions."""
8493
8594 loop_budget : int
8695
@@ -90,21 +99,23 @@ def __init__(
9099 loop_budget : int = 1 ,
91100 repair : Callable [
92101 [
93- Component ,
94102 Context ,
95- list [tuple [Requirement , ValidationResult ]],
96103 list [Component ],
104+ list [ModelOutputThunk ],
105+ list [list [tuple [Requirement , ValidationResult ]]],
97106 ],
98107 Component ,
99- ] = lambda i , c , r , h_i : i ,
108+ ]
109+ | None ,
100110 select_from_failure : Callable [
101111 [
102112 list [Component ],
103113 list [ModelOutputThunk ],
104114 list [list [tuple [Requirement , ValidationResult ]]],
105115 ],
106116 int ,
107- ] = lambda _ , results , __ : 0 ,
117+ ]
118+ | None ,
108119 validate : Callable [[list [Requirement ], Context , Any ], list [ValidationResult ]]
109120 | None = None ,
110121 generate : (
@@ -127,6 +138,9 @@ def __init__(
127138 AssertionError: If loop_budget is not greater than 0.
128139 """
129140 assert loop_budget > 0 , "Loop budget must be at least 1."
141+ assert repair is not None , "Repair must be provided."
142+ assert select_from_failure is not None , "Select from failure must be provided."
143+
130144 self .loop_budget = loop_budget
131145 self .repair = repair
132146 self .select_from_failure = select_from_failure
@@ -229,7 +243,7 @@ def sample(
229243
230244 # If we did not pass all constraints, update the instruction and try again.
231245 new_action = self .repair (
232- new_action , ctx , constraint_scores , sampled_actions
246+ ctx , sampled_actions , sampled_results , sampled_scores
233247 )
234248
235249 flog .info (
@@ -250,3 +264,120 @@ def sample(
250264 sample_validations = sampled_scores ,
251265 sample_actions = sampled_actions ,
252266 )
267+
268+
class RejectionSamplingStrategy(BaseSamplingStrategy):
    """Simple rejection sampling strategy with optional repair."""

    def __init__(
        self,
        *,
        loop_budget: int = 1,
        repair: Callable[
            [
                list[Component],
                list[ModelOutputThunk],
                list[list[tuple[Requirement, ValidationResult]]],
            ],
            Component,
        ] = lambda past_actions, past_results, past_val: past_actions[-1],
        select_from_failure: Callable[
            [
                list[Component],
                list[ModelOutputThunk],
                list[list[tuple[Requirement, ValidationResult]]],
            ],
            int,
        ] = lambda past_actions, past_results, past_val: 0,
        validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
        | None = None,
        generate: (
            Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
            | None
        ) = None,
        requirements: list[Requirement] | None = None,
    ):
        """Set up rejection sampling and hand the configuration to the base class.

        Args:
            loop_budget: Passed through unchanged to ``BaseSamplingStrategy``.
            repair: Context-free callback taking
                ``(past_actions, past_results, past_val)`` and returning the
                next ``Component`` to try. The default returns the most recent
                action (``past_actions[-1]``) unchanged.
            select_from_failure: Callback taking
                ``(past_actions, past_results, past_val)`` and returning an
                ``int``. The default always returns ``0``.
            validate: Passed through unchanged to ``BaseSamplingStrategy``.
            generate: Passed through unchanged to ``BaseSamplingStrategy``.
            requirements: Passed through unchanged to ``BaseSamplingStrategy``.
        """

        # The base class calls its repair hook with the context as the first
        # positional argument; this strategy's ``repair`` does not accept a
        # context, so the adapter drops it before delegating.
        def _drop_context(_ctx, actions, results, validations):
            return repair(actions, results, validations)

        super().__init__(
            requirements=requirements,
            generate=generate,
            validate=validate,
            select_from_failure=select_from_failure,
            repair=_drop_context,
            loop_budget=loop_budget,
        )
311+
312+
class AgenticSamplingStrategy(BaseSamplingStrategy):
    """Rejection sampling strategy with agentic (multi-turn) repair."""

    def __init__(
        self,
        *,
        loop_budget: int = 1,
        repair: Callable[
            [
                Context,
                list[Component],
                list[ModelOutputThunk],
                list[list[tuple[Requirement, ValidationResult]]],
            ],
            Component,
        ]
        | None = None,
        select_from_failure: Callable[
            [
                list[Component],
                list[ModelOutputThunk],
                list[list[tuple[Requirement, ValidationResult]]],
            ],
            int,
        ] = lambda past_actions, past_results, past_val: len(past_actions) - 1,
        validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
        | None = None,
        generate: (
            Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
            | None
        ) = None,
        requirements: list[Requirement] | None = None,
    ):
        """Set up agentic sampling and hand the configuration to the base class.

        Args:
            loop_budget: Passed through unchanged to ``BaseSamplingStrategy``.
            repair: Callback taking
                ``(context, past_actions, past_results, past_val)`` and
                returning the next ``Component`` to try. When ``None``,
                ``AgenticSamplingStrategy.agentic_repair_default`` is used.
            select_from_failure: Callback taking
                ``(past_actions, past_results, past_val)`` and returning an
                ``int``. The default returns ``len(past_actions) - 1``.
            validate: Passed through unchanged to ``BaseSamplingStrategy``.
            generate: Passed through unchanged to ``BaseSamplingStrategy``.
            requirements: Passed through unchanged to ``BaseSamplingStrategy``.
        """
        # ``None`` is the "use the built-in multi-turn repair" marker.
        effective_repair = (
            repair
            if repair is not None
            else AgenticSamplingStrategy.agentic_repair_default
        )

        super().__init__(
            requirements=requirements,
            generate=generate,
            validate=validate,
            select_from_failure=select_from_failure,
            repair=effective_repair,
            loop_budget=loop_budget,
        )

    @staticmethod
    def agentic_repair_default(
        context: Context,
        past_actions: list[Component],
        past_results: list[ModelOutputThunk],
        past_val: list[list[tuple[Requirement, ValidationResult]]],
    ) -> Component:
        """Build a follow-up user message listing the requirements that failed.

        The latest action/result pair is inserted into ``context`` as a new
        turn (the context is modified in place), and a ``Message`` with role
        ``"user"`` naming every requirement whose latest validation result was
        falsy is returned as the next action.

        Args:
            context: Must be a ``LinearContext``; it receives the failed turn.
            past_actions: Actions tried so far; only the last one is used.
            past_results: Results produced so far; only the last one is used.
            past_val: Per-attempt ``(requirement, validation result)`` pairs;
                only the last attempt is inspected.

        Returns:
            The ``Message`` to use as the next action.

        Raises:
            AssertionError: If ``context`` is not a ``LinearContext``.
        """
        assert isinstance(context, LinearContext), (
            " Need linear context to run agentic sampling."
        )

        # add failed execution to chat history
        failed_action = past_actions[-1]
        failed_result = past_results[-1]
        context.insert_turn(ContextTurn(failed_action, failed_result))

        # Keep only the requirements whose most recent validation was falsy.
        latest_scores = past_val[-1]
        unmet: list[Requirement] = [
            scored[0] for scored in latest_scores if not scored[1]
        ]
        last_failed_reqs_str = "* " + "\n * ".join(
            str(req.description) for req in unmet
        )
        # TODO: what to do with checks ??

        return Message(
            role="user",
            content=f"The following requirements have not been met: \n { last_failed_reqs_str } \n Please try again to fulfill the requirements.",
        )
0 commit comments