Skip to content

Commit e494305

Browse files
adding contexts to sampling
1 parent 590a285 commit e494305

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

mellea/stdlib/sampling.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,11 @@
44
from copy import deepcopy
55

66
import tqdm
7-
from PIL.IcnsImagePlugin import nextheader
87

98
import mellea.stdlib.mellea_functions as mfuncs
10-
from mellea import LegacyLinearContext
119
from mellea.backends import Backend, BaseModelSubclass
1210
from mellea.helpers.fancy_logger import FancyLogger
13-
from mellea.stdlib.base import (
14-
CBlock,
15-
ChatContext,
16-
Component,
17-
Context,
18-
ContextTurn,
19-
GenerateLog,
20-
LegacyContext,
21-
ModelOutputThunk,
22-
)
11+
from mellea.stdlib.base import CBlock, ChatContext, Component, Context, ModelOutputThunk
2312
from mellea.stdlib.chat import Message
2413
from mellea.stdlib.instruction import Instruction
2514
from mellea.stdlib.requirement import Requirement, ScorerRequirement, ValidationResult
@@ -31,12 +20,14 @@ class SamplingResult(CBlock):
3120
def __init__(
3221
self,
3322
result: ModelOutputThunk,
23+
result_ctx: Context,
3424
success: bool,
3525
*,
3626
sample_generations: list[ModelOutputThunk] | None = None,
3727
sample_validations: list[list[tuple[Requirement, ValidationResult]]]
3828
| None = None,
3929
sample_actions: list[Component] | None = None,
30+
sample_contexts: list[Context] | None = None,
4031
):
4132
"""Initialize a new instance of sampling results.
4233
@@ -48,14 +39,12 @@ def __init__(
4839
"""
4940
super().__init__(value=result.value)
5041
self.result = result
42+
self.result_ctx = result_ctx
5143
self.success = success
5244
self.sample_generations = sample_generations
5345
self.sample_validations = sample_validations
5446
self.sample_actions = sample_actions
55-
56-
# TODO: JAL. add these fields
57-
# context,
58-
# sample_contexts=[Context] # TODO: JAL. implement this.
47+
self.sample_contexts = sample_contexts
5948

6049

6150
class SamplingStrategy(abc.ABC):
@@ -222,8 +211,8 @@ async def sample(
222211
flog.info(f"Running loop {loop_count} of {self.loop_budget}")
223212

224213
# run a generation pass
225-
# TODO: JAL. figure out where to put new_ctx; ie we also need to return it with sampling results
226-
result, new_ctx = backend.generate_from_context(
214+
# TODO: JAL. figure out where to put result_ctx; i.e., it also needs to be returned with the sampling results
215+
result, result_ctx = backend.generate_from_context(
227216
next_action,
228217
ctx=next_context,
229218
format=format,
@@ -233,10 +222,10 @@ async def sample(
233222
await result.avalue()
234223

235224
# validation pass
236-
# TODO: JAL. see if we are supposed to be passing output here since the new_ctx theoretically already has it
225+
# TODO: JAL. check whether output still needs to be passed here, since result_ctx should theoretically already contain it
237226
val_scores_co = mfuncs._validate(
238227
reqs=reqs,
239-
context=new_ctx,
228+
context=result_ctx,
240229
backend=backend,
241230
output=result,
242231
format=format,
@@ -252,7 +241,7 @@ async def sample(
252241
sampled_results.append(result)
253242
sampled_scores.append(constraint_scores)
254243
sampled_actions.append(next_action)
255-
sample_contexts.append(new_ctx)
244+
sample_contexts.append(result_ctx)
256245

257246
# if all validations passed -- stop sampling and return success
258247
if all(bool(s[1]) for s in constraint_scores):
@@ -262,11 +251,15 @@ async def sample(
262251
) # Cannot be None after generation.
263252
result._generate_log.is_final_result = True
264253

254+
# Success: every requirement validated, so return this generation together with its context.
265255
return SamplingResult(
266-
result,
256+
result=result,
257+
result_ctx=result_ctx,
267258
success=True,
268259
sample_generations=sampled_results,
269260
sample_validations=sampled_scores,
261+
sample_contexts=sample_contexts,
262+
sample_actions=sampled_actions,
270263
)
271264

272265
else:
@@ -276,7 +269,11 @@ async def sample(
276269

277270
# If we did not pass all constraints, update the instruction and try again.
278271
next_action, next_context = self.repair(
279-
next_context, new_ctx, sampled_actions, sampled_results, sampled_scores
272+
next_context,
273+
result_ctx,
274+
sampled_actions,
275+
sampled_results,
276+
sampled_scores,
280277
)
281278

282279
flog.info(
@@ -297,11 +294,13 @@ async def sample(
297294
sampled_results[best_failed_index]._generate_log.is_final_result = True # type: ignore
298295

299296
return SamplingResult(
300-
sampled_results[best_failed_index],
297+
result=sampled_results[best_failed_index],
298+
result_ctx=sample_contexts[best_failed_index],
301299
success=False,
302300
sample_generations=sampled_results,
303301
sample_validations=sampled_scores,
304302
sample_actions=sampled_actions,
303+
sample_contexts=sample_contexts,
305304
)
306305

307306

0 commit comments

Comments
 (0)