44from copy import deepcopy
55
66import tqdm
7- from PIL .IcnsImagePlugin import nextheader
87
98import mellea .stdlib .mellea_functions as mfuncs
10- from mellea import LegacyLinearContext
119from mellea .backends import Backend , BaseModelSubclass
1210from mellea .helpers .fancy_logger import FancyLogger
13- from mellea .stdlib .base import (
14- CBlock ,
15- ChatContext ,
16- Component ,
17- Context ,
18- ContextTurn ,
19- GenerateLog ,
20- LegacyContext ,
21- ModelOutputThunk ,
22- )
11+ from mellea .stdlib .base import CBlock , ChatContext , Component , Context , ModelOutputThunk
2312from mellea .stdlib .chat import Message
2413from mellea .stdlib .instruction import Instruction
2514from mellea .stdlib .requirement import Requirement , ScorerRequirement , ValidationResult
@@ -31,12 +20,14 @@ class SamplingResult(CBlock):
3120 def __init__ (
3221 self ,
3322 result : ModelOutputThunk ,
23+ result_ctx : Context ,
3424 success : bool ,
3525 * ,
3626 sample_generations : list [ModelOutputThunk ] | None = None ,
3727 sample_validations : list [list [tuple [Requirement , ValidationResult ]]]
3828 | None = None ,
3929 sample_actions : list [Component ] | None = None ,
30+ sample_contexts : list [Context ] | None = None ,
4031 ):
4132 """Initialize a new instance of sampling results.
4233
@@ -48,14 +39,12 @@ def __init__(
4839 """
4940 super ().__init__ (value = result .value )
5041 self .result = result
42+ self .result_ctx = result_ctx
5143 self .success = success
5244 self .sample_generations = sample_generations
5345 self .sample_validations = sample_validations
5446 self .sample_actions = sample_actions
55-
56- # TODO: JAL. add these fields
57- # context,
58- # sample_contexts=[Context] # TODO: JAL. implement this.
47+ self .sample_contexts = sample_contexts
5948
6049
6150class SamplingStrategy (abc .ABC ):
@@ -222,8 +211,8 @@ async def sample(
222211 flog .info (f"Running loop { loop_count } of { self .loop_budget } " )
223212
224213 # run a generation pass
225- # TODO: JAL. figure out where to put new_ctx ; ie we also need to return it with sampling results
226- result , new_ctx = backend .generate_from_context (
214+ # TODO: JAL. figure out where to put result_ctx ; ie we also need to return it with sampling results
215+ result , result_ctx = backend .generate_from_context (
227216 next_action ,
228217 ctx = next_context ,
229218 format = format ,
@@ -233,10 +222,10 @@ async def sample(
233222 await result .avalue ()
234223
235224 # validation pass
236- # TODO: JAL. see if we are supposed to be passing output here since the new_ctx theoretically already has it
225+ # TODO: JAL. see if we are supposed to be passing output here since the result_ctx theoretically already has it
237226 val_scores_co = mfuncs ._validate (
238227 reqs = reqs ,
239- context = new_ctx ,
228+ context = result_ctx ,
240229 backend = backend ,
241230 output = result ,
242231 format = format ,
@@ -252,7 +241,7 @@ async def sample(
252241 sampled_results .append (result )
253242 sampled_scores .append (constraint_scores )
254243 sampled_actions .append (next_action )
255- sample_contexts .append (new_ctx )
244+ sample_contexts .append (result_ctx )
256245
257246 # if all vals are true -- break and return success
258247 if all (bool (s [1 ]) for s in constraint_scores ):
@@ -262,11 +251,15 @@ async def sample(
262251 ) # Cannot be None after generation.
263252 result ._generate_log .is_final_result = True
264253
254+ # SUCCESS !!!!
265255 return SamplingResult (
266- result ,
256+ result = result ,
257+ result_ctx = result_ctx ,
267258 success = True ,
268259 sample_generations = sampled_results ,
269260 sample_validations = sampled_scores ,
261+ sample_contexts = sample_contexts ,
262+ sample_actions = sampled_actions ,
270263 )
271264
272265 else :
@@ -276,7 +269,11 @@ async def sample(
276269
277270 # If we did not pass all constraints, update the instruction and try again.
278271 next_action , next_context = self .repair (
279- next_context , new_ctx , sampled_actions , sampled_results , sampled_scores
272+ next_context ,
273+ result_ctx ,
274+ sampled_actions ,
275+ sampled_results ,
276+ sampled_scores ,
280277 )
281278
282279 flog .info (
@@ -297,11 +294,13 @@ async def sample(
297294 sampled_results [best_failed_index ]._generate_log .is_final_result = True # type: ignore
298295
299296 return SamplingResult (
300- sampled_results [best_failed_index ],
297+ result = sampled_results [best_failed_index ],
298+ result_ctx = sample_contexts [best_failed_index ],
301299 success = False ,
302300 sample_generations = sampled_results ,
303301 sample_validations = sampled_scores ,
304302 sample_actions = sampled_actions ,
303+ sample_contexts = sample_contexts ,
305304 )
306305
307306
0 commit comments