22
33import abc
44from collections .abc import Callable
5+ from copy import deepcopy
56from typing import Any
67
78import tqdm
89
910from mellea .helpers .fancy_logger import FancyLogger
10- from mellea .stdlib .base import CBlock , GenerateLog , ModelOutputThunk
11+ from mellea .stdlib .base import CBlock , Component , Context , GenerateLog , ModelOutputThunk
1112from mellea .stdlib .instruction import Instruction
1213from mellea .stdlib .requirement import Requirement , ValidationResult
1314
@@ -23,6 +24,7 @@ def __init__(
2324 sample_generations : list [ModelOutputThunk ] | None = None ,
2425 sample_validations : list [list [tuple [Requirement , ValidationResult ]]]
2526 | None = None ,
27+ sample_actions : list [Component ] | None = None ,
2628 ):
2729 """Initialize a new instance of sampling results.
2830
@@ -47,56 +49,67 @@ class SamplingStrategy(abc.ABC):
4749 """
4850
4951 # the function signature here matches that of m.validate
50- validate : Callable [[list [Requirement ], Any ], list [ValidationResult ]] | None = None
52+ validate : (
53+ Callable [[list [Requirement ], Context , Any ], list [ValidationResult ]] | None
54+ ) = None
5155
5256 generate : (
53- Callable [[Instruction , list [GenerateLog ] | None ], ModelOutputThunk ] | None
57+ Callable [[Component , Context , list [GenerateLog ] | None ], ModelOutputThunk ]
58+ | None
5459 ) = None
5560
5661 @abc .abstractmethod
5762 def sample (
5863 self ,
59- instruction : Instruction ,
64+ action : Component ,
65+ context : Context ,
6066 * ,
6167 generate_logs : list [GenerateLog ] | None = None ,
68+ validation_ctx : Context | None = None ,
6269 ) -> SamplingResult :
6370 """This method is the abstract method for sampling a given instruction.
6471
6572 It must be implemented by any concrete subclasses to provide specific sampling logic.
6673
6774 Args:
68- instruction (Instruction): The instruction object to be sampled.
75+ action : The action object to be sampled.
76+ context: The context to be passed to the sampling strategy.
6977 generate_logs: Optional list of GenerateLog objects. If None, no collection happens.
78+ validation_ctx: Optional context to use for validation. If None, the `context` argument is used for validation.
7079 """
7180
7281
7382class RejectionSamplingStrategy (SamplingStrategy ):
7483 """Sampling strategy that rejects samples based on given instructions."""
7584
85+ loop_budget : int
86+
7687 def __init__ (
7788 self ,
7889 * ,
7990 loop_budget : int = 1 ,
8091 repair : Callable [
8192 [
82- Instruction ,
93+ Component ,
94+ Context ,
8395 list [tuple [Requirement , ValidationResult ]],
84- list [Instruction ],
96+ list [Component ],
8597 ],
86- Instruction ,
87- ] = lambda i , r , h_i : i ,
98+ Component ,
99+ ] = lambda i , c , r , h_i : i ,
88100 select_from_failure : Callable [
89101 [
90- Instruction ,
102+ list [ Component ] ,
91103 list [ModelOutputThunk ],
92104 list [list [tuple [Requirement , ValidationResult ]]],
93105 ],
94- ModelOutputThunk ,
95- ] = lambda _ , results , __ : results [ 0 ] ,
96- validate : Callable [[list [Requirement ], Any ], list [ValidationResult ]]
106+ int ,
107+ ] = lambda _ , results , __ : 0 ,
108+ validate : Callable [[list [Requirement ], Context , Any ], list [ValidationResult ]]
97109 | None = None ,
98110 generate : (
99- Callable [[Instruction , list [GenerateLog ] | None ], ModelOutputThunk ] | None
111+ Callable [[Component , Context , list [GenerateLog ] | None ], ModelOutputThunk ]
112+ | None
100113 ) = None ,
101114 requirements : list [Requirement ] | None = None ,
102115 ):
@@ -123,17 +136,23 @@ def __init__(
123136
124137 def sample (
125138 self ,
126- instruction : Instruction ,
139+ action : Component ,
140+ context : Context ,
127141 * ,
128142 show_progress : bool = True ,
129143 generate_logs : list [GenerateLog ] | None = None ,
144+ requirements : list [Requirement ] | None = None ,
145+ validation_ctx : Context | None = None ,
130146 ) -> SamplingResult :
131147 """This method performs a sampling operation based on the given instruction.
132148
133149 Args:
134- instruction: The Instruction object containing the instruction to generate a valid model output thunk.
135- show_progress: if true, a tqdm progress bar is used. Otherwise messages will still be sent to flog.
150+ action : The action object to be sampled.
151+ context: The context to be passed to the sampling strategy.
152+ show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
136153 generate_logs: If provided, the generations will be logged.
154+ requirements: List of requirements to test against.
155+ validation_ctx: Optional context to use for validation. If None, the `context` argument is used for validation.
137156
138157 Returns:
139158 SamplingResult: A result object indicating the success or failure of the sampling process.
@@ -148,68 +167,86 @@ def sample(
148167 assert self .validate is not None , "Validation must be provided."
149168 assert self .generate is not None , "Generate must be provided."
150169
170+ # work on a copy so that the caller's original context is left untouched
171+ ctx = context .copy ()
172+ validation_ctx = validation_ctx if validation_ctx is not None else context
173+ assert validation_ctx is not None , "Validation context must be provided."
174+
151175 flog = FancyLogger .get_logger ()
152176
153- failed_results : list [ModelOutputThunk ] = []
154- failed_scores : list [list [tuple [Requirement , ValidationResult ]]] = []
155- failed_instructions : list [Instruction ] = []
177+ sampled_results : list [ModelOutputThunk ] = []
178+ sampled_scores : list [list [tuple [Requirement , ValidationResult ]]] = []
179+ sampled_actions : list [Component ] = []
156180
157- loop_count = 0
181+ if self .requirements is not None :
182+ reqs = self .requirements
183+ if requirements is not None :
184+ flog .warn ("Some requirements are ignored." )
185+ else :
186+ reqs = requirements if requirements is not None else []
158187
188+ loop_count = 0
159189 loop_budget_range_iterator = (
160- tqdm .tqdm (range (self .loop_budget ))
190+ tqdm .tqdm (range (self .loop_budget )) # type: ignore
161191 if show_progress
162- else range (self .loop_budget )
192+ else range (self .loop_budget ) # type: ignore
163193 )
194+
195+ new_action = deepcopy (action )
164196 for _ in loop_budget_range_iterator : # type: ignore
165197 loop_count += 1
166198 if not show_progress :
167199 flog .info (f"Running loop { loop_count } of { self .loop_budget } " )
168200
169- # run a pass
201+ # run a generation pass
202+ result = self .generate (new_action , ctx , generate_logs )
170203
171- result = self .generate (instruction , generate_logs )
204+ # validation pass
205+ val_scores = self .validate (reqs , validation_ctx , result )
172206
173- if self .requirements is not None :
174- reqs = self .requirements
175- else :
176- reqs = instruction .requirements
177- val_scores = self .validate (reqs , result )
207+ # match up reqs with scores
178208 constraint_scores = list (zip (reqs , val_scores ))
179209
180- failed_results .append (result )
181- failed_scores .append (constraint_scores )
182- failed_instructions .append (instruction )
210+ # collect all data
211+ sampled_results .append (result )
212+ sampled_scores .append (constraint_scores )
213+ sampled_actions .append (new_action )
183214
215+ # if all vals are true -- break and return success
184216 if all (bool (s [1 ]) for s in constraint_scores ):
185217 flog .info ("SUCCESS" )
186218 return SamplingResult (
187219 result ,
188220 success = True ,
189- sample_generations = failed_results ,
190- sample_validations = failed_scores ,
221+ sample_generations = sampled_results ,
222+ sample_validations = sampled_scores ,
191223 )
192224
193225 else :
226+ # log partial success and continue
194227 count_valid = len ([s for s in constraint_scores if bool (s [1 ])])
195228 flog .info (f"FAILED. Valid: { count_valid } /{ len (constraint_scores )} " )
229+
196230 # If we did not pass all constraints, update the instruction and try again.
197- instruction = self .repair (
198- instruction , constraint_scores , failed_instructions
231+ new_action = self .repair (
232+ new_action , ctx , constraint_scores , sampled_actions
199233 )
200234
201235 flog .info (
202- f"Invoking select_from_failure after { len (failed_results )} failed attempts."
236+ f"Invoking select_from_failure after { len (sampled_results )} failed attempts."
203237 )
204- best_failed_result = self .select_from_failure (
205- instruction , failed_results , failed_scores
238+
239+ # if no valid result could be determined, find a last resort.
240+ best_failed_index = self .select_from_failure (
241+ sampled_actions , sampled_results , sampled_scores
206242 )
207- assert best_failed_result in failed_results , (
243+ assert best_failed_index < len ( sampled_results ) , (
208244 "The select_from_failure method did not return a valid result. It has to selected from failed_results."
209245 )
210246 return SamplingResult (
211- best_failed_result ,
247+ sampled_results [ best_failed_index ] ,
212248 success = False ,
213- sample_generations = failed_results ,
214- sample_validations = failed_scores ,
249+ sample_generations = sampled_results ,
250+ sample_validations = sampled_scores ,
251+ sample_actions = sampled_actions ,
215252 )
0 commit comments