Skip to content

Commit f02e88b

Browse files
new signature for RejectionSampling
1 parent e4b7cd4 commit f02e88b

File tree

2 files changed

+87
-48
lines changed

2 files changed

+87
-48
lines changed

mellea/stdlib/sampling.py

Lines changed: 81 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,13 @@
22

33
import abc
44
from collections.abc import Callable
5+
from copy import deepcopy
56
from typing import Any
67

78
import tqdm
89

910
from mellea.helpers.fancy_logger import FancyLogger
10-
from mellea.stdlib.base import CBlock, GenerateLog, ModelOutputThunk
11+
from mellea.stdlib.base import CBlock, Component, Context, GenerateLog, ModelOutputThunk
1112
from mellea.stdlib.instruction import Instruction
1213
from mellea.stdlib.requirement import Requirement, ValidationResult
1314

@@ -23,6 +24,7 @@ def __init__(
2324
sample_generations: list[ModelOutputThunk] | None = None,
2425
sample_validations: list[list[tuple[Requirement, ValidationResult]]]
2526
| None = None,
27+
sample_actions: list[Component] | None = None,
2628
):
2729
"""Initialize a new instance of sampling results.
2830
@@ -47,56 +49,67 @@ class SamplingStrategy(abc.ABC):
4749
"""
4850

4951
# the function signature here matches that of m.validate
50-
validate: Callable[[list[Requirement], Any], list[ValidationResult]] | None = None
52+
validate: (
53+
Callable[[list[Requirement], Context, Any], list[ValidationResult]] | None
54+
) = None
5155

5256
generate: (
53-
Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None
57+
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
58+
| None
5459
) = None
5560

5661
@abc.abstractmethod
5762
def sample(
5863
self,
59-
instruction: Instruction,
64+
action: Component,
65+
context: Context,
6066
*,
6167
generate_logs: list[GenerateLog] | None = None,
68+
validation_ctx: Context | None = None,
6269
) -> SamplingResult:
6370
"""This method is the abstract method for sampling a given action in a given context.
6471
6572
It must be implemented by any concrete subclasses to provide specific sampling logic.
6673
6774
Args:
68-
instruction (Instruction): The instruction object to be sampled.
75+
action: The action object to be sampled.
76+
context: The context to be passed to the sampling strategy.
6977
generate_logs: Optional list of GenerateLog objects. If None, no collection happens.
78+
validation_ctx: Optional context to use for validation. If None, `context` is used for validation.
7079
"""
7180

7281

7382
class RejectionSamplingStrategy(SamplingStrategy):
7483
"""Sampling strategy that rejects samples based on given instructions."""
7584

85+
loop_budget: int
86+
7687
def __init__(
7788
self,
7889
*,
7990
loop_budget: int = 1,
8091
repair: Callable[
8192
[
82-
Instruction,
93+
Component,
94+
Context,
8395
list[tuple[Requirement, ValidationResult]],
84-
list[Instruction],
96+
list[Component],
8597
],
86-
Instruction,
87-
] = lambda i, r, h_i: i,
98+
Component,
99+
] = lambda i, c, r, h_i: i,
88100
select_from_failure: Callable[
89101
[
90-
Instruction,
102+
list[Component],
91103
list[ModelOutputThunk],
92104
list[list[tuple[Requirement, ValidationResult]]],
93105
],
94-
ModelOutputThunk,
95-
] = lambda _, results, __: results[0],
96-
validate: Callable[[list[Requirement], Any], list[ValidationResult]]
106+
int,
107+
] = lambda _, results, __: 0,
108+
validate: Callable[[list[Requirement], Context, Any], list[ValidationResult]]
97109
| None = None,
98110
generate: (
99-
Callable[[Instruction, list[GenerateLog] | None], ModelOutputThunk] | None
111+
Callable[[Component, Context, list[GenerateLog] | None], ModelOutputThunk]
112+
| None
100113
) = None,
101114
requirements: list[Requirement] | None = None,
102115
):
@@ -123,17 +136,23 @@ def __init__(
123136

124137
def sample(
125138
self,
126-
instruction: Instruction,
139+
action: Component,
140+
context: Context,
127141
*,
128142
show_progress: bool = True,
129143
generate_logs: list[GenerateLog] | None = None,
144+
requirements: list[Requirement] | None = None,
145+
validation_ctx: Context | None = None,
130146
) -> SamplingResult:
131147
"""This method performs a sampling operation based on the given action and context.
132148
133149
Args:
134-
instruction: The Instruction object containing the instruction to generate a valid model output thunk.
135-
show_progress: if true, a tqdm progress bar is used. Otherwise messages will still be sent to flog.
150+
action: The action object to be sampled.
151+
context: The context to be passed to the sampling strategy.
152+
show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
136153
generate_logs: If provided, the generations will be logged.
154+
requirements: List of requirements to test against.
155+
validation_ctx: Optional context to use for validation. If None, `context` is used for validation.
137156
138157
Returns:
139158
SamplingResult: A result object indicating the success or failure of the sampling process.
@@ -148,68 +167,86 @@ def sample(
148167
assert self.validate is not None, "Validation must be provided."
149168
assert self.generate is not None, "Generate must be provided."
150169

170+
# generate against a copy of the context so the caller's original context object is left alone
171+
ctx = context.copy()
172+
validation_ctx = validation_ctx if validation_ctx is not None else context
173+
assert validation_ctx is not None, "Validation context must be provided."
174+
151175
flog = FancyLogger.get_logger()
152176

153-
failed_results: list[ModelOutputThunk] = []
154-
failed_scores: list[list[tuple[Requirement, ValidationResult]]] = []
155-
failed_instructions: list[Instruction] = []
177+
sampled_results: list[ModelOutputThunk] = []
178+
sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
179+
sampled_actions: list[Component] = []
156180

157-
loop_count = 0
181+
if self.requirements is not None:
182+
reqs = self.requirements
183+
if requirements is not None:
184+
flog.warn("Some requirements are ignored.")
185+
else:
186+
reqs = requirements if requirements is not None else []
158187

188+
loop_count = 0
159189
loop_budget_range_iterator = (
160-
tqdm.tqdm(range(self.loop_budget))
190+
tqdm.tqdm(range(self.loop_budget)) # type: ignore
161191
if show_progress
162-
else range(self.loop_budget)
192+
else range(self.loop_budget) # type: ignore
163193
)
194+
195+
new_action = deepcopy(action)
164196
for _ in loop_budget_range_iterator: # type: ignore
165197
loop_count += 1
166198
if not show_progress:
167199
flog.info(f"Running loop {loop_count} of {self.loop_budget}")
168200

169-
# run a pass
201+
# run a generation pass
202+
result = self.generate(new_action, ctx, generate_logs)
170203

171-
result = self.generate(instruction, generate_logs)
204+
# validation pass
205+
val_scores = self.validate(reqs, validation_ctx, result)
172206

173-
if self.requirements is not None:
174-
reqs = self.requirements
175-
else:
176-
reqs = instruction.requirements
177-
val_scores = self.validate(reqs, result)
207+
# match up reqs with scores
178208
constraint_scores = list(zip(reqs, val_scores))
179209

180-
failed_results.append(result)
181-
failed_scores.append(constraint_scores)
182-
failed_instructions.append(instruction)
210+
# collect all data
211+
sampled_results.append(result)
212+
sampled_scores.append(constraint_scores)
213+
sampled_actions.append(new_action)
183214

215+
# if all vals are true -- break and return success
184216
if all(bool(s[1]) for s in constraint_scores):
185217
flog.info("SUCCESS")
186218
return SamplingResult(
187219
result,
188220
success=True,
189-
sample_generations=failed_results,
190-
sample_validations=failed_scores,
221+
sample_generations=sampled_results,
222+
sample_validations=sampled_scores,
191223
)
192224

193225
else:
226+
# log partial success and continue
194227
count_valid = len([s for s in constraint_scores if bool(s[1])])
195228
flog.info(f"FAILED. Valid: {count_valid}/{len(constraint_scores)}")
229+
196230
# If we did not pass all constraints, repair the action and try again.
197-
instruction = self.repair(
198-
instruction, constraint_scores, failed_instructions
231+
new_action = self.repair(
232+
new_action, ctx, constraint_scores, sampled_actions
199233
)
200234

201235
flog.info(
202-
f"Invoking select_from_failure after {len(failed_results)} failed attempts."
236+
f"Invoking select_from_failure after {len(sampled_results)} failed attempts."
203237
)
204-
best_failed_result = self.select_from_failure(
205-
instruction, failed_results, failed_scores
238+
239+
# if no valid result could be determined, find a last resort.
240+
best_failed_index = self.select_from_failure(
241+
sampled_actions, sampled_results, sampled_scores
206242
)
207-
assert best_failed_result in failed_results, (
243+
assert best_failed_index < len(sampled_results), (
208244
"The select_from_failure method did not return a valid result. It has to selected from failed_results."
209245
)
210246
return SamplingResult(
211-
best_failed_result,
247+
sampled_results[best_failed_index],
212248
success=False,
213-
sample_generations=failed_results,
214-
sample_validations=failed_scores,
249+
sample_generations=sampled_results,
250+
sample_validations=sampled_scores,
251+
sample_actions=sampled_actions,
215252
)

mellea/stdlib/session.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -199,15 +199,17 @@ def instruct(
199199
generate_logs[0].is_final_result = True
200200
else:
201201
if strategy.validate is None:
202-
strategy.validate = lambda reqs, output: self.validate( # type: ignore
202+
strategy.validate = lambda reqs, val_ctx, output: self.validate( # type: ignore
203203
reqs,
204204
output=output, # type: ignore
205205
) # type: ignore
206206
if strategy.generate is None:
207207
strategy.generate = (
208-
lambda instruction, g_logs: self.backend.generate_from_context(
208+
lambda instruction,
209+
gen_ctx,
210+
g_logs: self.backend.generate_from_context(
209211
instruction,
210-
ctx=self.ctx,
212+
ctx=gen_ctx,
211213
format=format,
212214
model_options=model_options,
213215
generate_logs=g_logs,
@@ -216,7 +218,7 @@ def instruct(
216218
)
217219

218220
# sample
219-
res = strategy.sample(i, generate_logs=generate_logs)
221+
res = strategy.sample(i, self.ctx, generate_logs=generate_logs)
220222

221223
# make sure that one Log is marked as the one related to res.result
222224
if res.success:

0 commit comments

Comments
 (0)