@@ -99,25 +99,6 @@ def __init__(
9999 self ,
100100 * ,
101101 loop_budget : int = 1 ,
102- repair : Callable [
103- [
104- Context ,
105- list [Component ],
106- list [ModelOutputThunk ],
107- list [list [tuple [Requirement , ValidationResult ]]],
108- ],
109- Component ,
110- ]
111- | None ,
112- select_from_failure : Callable [
113- [
114- list [Component ],
115- list [ModelOutputThunk ],
116- list [list [tuple [Requirement , ValidationResult ]]],
117- ],
118- int ,
119- ]
120- | None ,
121102 validate : Callable [[list [Requirement ], Context , Any ], list [ValidationResult ]]
122103 | None = None ,
123104 generate : (
@@ -130,8 +111,6 @@ def __init__(
130111
131112 Args:
132113 loop_budget: Number of times to iterate through the process. Must be greater than 0.
133- repair: Function to apply "repairs" to an instruction based on its requirements and validation results.
134- select_from_failure: Function to select a model output thunk from failed attempts.
135114 validate: Function to validate the results against requirements. If None, validation is provided later through setter.
136115 generate: Function to generate new model output thunks. If None, generate is provided later through setter.
137116 requirements: List of requirements to test against. If None, test all requirements attached to the given instruction.
@@ -140,16 +119,53 @@ def __init__(
140119 AssertionError: If loop_budget is not greater than 0.
141120 """
142121 assert loop_budget > 0 , "Loop budget must be at least 1."
143- assert repair is not None , "Repair must be provided."
144- assert select_from_failure is not None , "Select from failure must be provided."
145122
146123 self .loop_budget = loop_budget
147- self .repair = repair
148- self .select_from_failure = select_from_failure
149124 self .validate = validate # it's ok to be None here
150125 self .generate = generate # it's ok to be None here
151126 self .requirements = requirements
152127
128+ @staticmethod
129+ @abc .abstractmethod
130+ def repair (
131+ ctx : Context ,
132+ past_actions : list [Component ],
133+ past_results : list [ModelOutputThunk ],
134+ past_val : list [list [tuple [Requirement , ValidationResult ]]],
135+ ) -> Component :
136+ """
137+ Repair function that is being invoked if not all requirements are fulfilled. It should return a next action component.
138+
139+ Args:
140+ ctx: The context to be passed to the sampling strategy.
141+ past_actions: List of actions that have been executed (without success).
142+ past_results: List of (unsuccessful) generation results for these actions.
143+ past_val: List of validation results for the results.
144+
145+ Returns:
146+ The next action component.
147+ """
148+ ...
149+
150+ @staticmethod
151+ @abc .abstractmethod
152+ def select_from_failure (
153+ sampled_actions : list [Component ],
154+ sampled_results : list [ModelOutputThunk ],
155+ sampled_val : list [list [tuple [Requirement , ValidationResult ]]],
156+ ):
157+ """This function returns the index of the result that should be selected as `.value` iff the loop budget is exhausted and no success.
158+
159+ Args:
160+ sampled_actions: List of actions that have been executed (without success).
161+ sampled_results: List of (unsuccessful) generation results for these actions.
162+ sampled_val: List of validation results for the results.
163+
164+ Returns:
165+ The index of the result that should be selected as `.value`.
166+ """
167+ ...
168+
153169 def sample (
154170 self ,
155171 action : Component ,
@@ -176,10 +192,6 @@ def sample(
176192 Raises:
177193 AssertionError: Asserts that all required components (repair, select_from_failure, validate, and generate) are provided before proceeding with the sampling.
178194 """
179- assert self .repair is not None , "Repair must be provided."
180- assert self .select_from_failure is not None , (
181- "Select from failure must be provided."
182- )
183195 assert self .validate is not None , "Validation must be provided."
184196 assert self .generate is not None , "Generate must be provided."
185197
@@ -271,96 +283,75 @@ def sample(
271283
272284
273285class RejectionSamplingStrategy (BaseSamplingStrategy ):
274- """Simple rejection sampling strategy with optional repair ."""
286+ """Simple rejection sampling strategy that just repeats the same call on failure ."""
275287
276- def __init__ (
277- self ,
278- * ,
279- loop_budget : int = 1 ,
280- repair : Callable [
281- [
282- list [Component ],
283- list [ModelOutputThunk ],
284- list [list [tuple [Requirement , ValidationResult ]]],
285- ],
286- Component ,
287- ] = lambda past_actions , past_results , past_val : past_actions [- 1 ],
288- select_from_failure : Callable [
289- [
290- list [Component ],
291- list [ModelOutputThunk ],
292- list [list [tuple [Requirement , ValidationResult ]]],
293- ],
294- int ,
295- ] = lambda past_actions , past_results , past_val : 0 ,
296- validate : Callable [[list [Requirement ], Context , Any ], list [ValidationResult ]]
297- | None = None ,
298- generate : (
299- Callable [[Component , Context , list [GenerateLog ] | None ], ModelOutputThunk ]
300- | None
301- ) = None ,
302- requirements : list [Requirement ] | None = None ,
303- ):
304- def repair_wrapper (_ , past_actions , past_results , past_val ):
305- return repair (past_actions , past_results , past_val )
306-
307- super ().__init__ (
308- loop_budget = loop_budget ,
309- repair = repair_wrapper ,
310- select_from_failure = select_from_failure ,
311- validate = validate ,
312- generate = generate ,
313- requirements = requirements ,
314- )
288+ @staticmethod
289+ def select_from_failure (
290+ sampled_actions : list [Component ],
291+ sampled_results : list [ModelOutputThunk ],
292+ sampled_val : list [list [tuple [Requirement , ValidationResult ]]],
293+ ) -> int :
294+ # simply returns the first attempt if all loops fail
295+ return 0
315296
297+ @staticmethod
298+ def repair (
299+ ctx : Context ,
300+ past_actions : list [Component ],
301+ past_results : list [ModelOutputThunk ],
302+ past_val : list [list [tuple [Requirement , ValidationResult ]]],
303+ ) -> Component :
304+ # repeat the last action again.
305+ return past_actions [- 1 ]
316306
317- class AgenticSamplingStrategy (BaseSamplingStrategy ):
318- """Rejection sampling strategy with agentic (multi-turn) repair."""
319307
320- def __init__ (
321- self ,
322- * ,
323- loop_budget : int = 1 ,
324- repair : Callable [
325- [
326- Context ,
327- list [Component ],
328- list [ModelOutputThunk ],
329- list [list [tuple [Requirement , ValidationResult ]]],
330- ],
331- Component ,
332- ]
333- | None = None ,
334- select_from_failure : Callable [
335- [
336- list [Component ],
337- list [ModelOutputThunk ],
338- list [list [tuple [Requirement , ValidationResult ]]],
339- ],
340- int ,
341- ] = lambda past_actions , past_results , past_val : len (past_actions ) - 1 ,
342- validate : Callable [[list [Requirement ], Context , Any ], list [ValidationResult ]]
343- | None = None ,
344- generate : (
345- Callable [[Component , Context , list [GenerateLog ] | None ], ModelOutputThunk ]
346- | None
347- ) = None ,
348- requirements : list [Requirement ] | None = None ,
308+ class RepairTemplateStrategy (BaseSamplingStrategy ):
309+ """A sampling strategy that adds a repair string to the instruction object."""
310+
311+ @staticmethod
312+ def select_from_failure (
313+ sampled_actions : list [Component ],
314+ sampled_results : list [ModelOutputThunk ],
315+ sampled_val : list [list [tuple [Requirement , ValidationResult ]]],
316+ ) -> int :
317+ # simply returns the first attempt if all loops fail
318+ return 0
319+
320+ @staticmethod
321+ def repair (
322+ ctx : Context ,
323+ past_actions : list [Component ],
324+ past_results : list [ModelOutputThunk ],
325+ past_val : list [list [tuple [Requirement , ValidationResult ]]],
326+ ) -> Component :
327+ pa = past_actions [- 1 ]
328+ if isinstance (pa , Instruction ):
329+ last_failed_reqs : list [Requirement ] = [
330+ s [0 ] for s in past_val [- 1 ] if not s [1 ]
331+ ]
332+ last_failed_reqs_str = "* " + "\n * " .join (
333+ [str (r .description ) for r in last_failed_reqs ]
334+ )
335+ return pa .copy_and_repair (
336+ repair_string = f"The following requirements failed before:\n { last_failed_reqs_str } "
337+ )
338+ return past_actions [- 1 ]
339+
340+
341+ class MultiTurnStrategy (BaseSamplingStrategy ):
342+ """Rejection sampling strategy with (agentic) multi-turn repair."""
343+
344+ @staticmethod
345+ def select_from_failure (
346+ sampled_actions : list [Component ],
347+ sampled_results : list [ModelOutputThunk ],
348+ sampled_val : list [list [tuple [Requirement , ValidationResult ]]],
349349 ):
350- if repair is None :
351- repair = AgenticSamplingStrategy .agentic_repair_default
352-
353- super ().__init__ (
354- loop_budget = loop_budget ,
355- repair = repair ,
356- select_from_failure = select_from_failure ,
357- validate = validate ,
358- generate = generate ,
359- requirements = requirements ,
360- )
350+ # return the last assistant message even if all attempts of repair failed.
351+ return - 1
361352
362353 @staticmethod
363- def agentic_repair_default (
354+ def repair (
364355 context : Context ,
365356 past_actions : list [Component ],
366357 past_results : list [ModelOutputThunk ],
0 commit comments