Skip to content

Commit decab14

Browse files
fixing requirements and adding tests
1 parent c9de1f9 commit decab14

File tree

3 files changed

+77
-8
lines changed

3 files changed

+77
-8
lines changed

mellea/stdlib/sampling.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def sample(
7272
self,
7373
action: Component,
7474
context: Context,
75+
requirements: list[Requirement],
7576
*,
7677
generate_logs: list[GenerateLog] | None = None,
7778
validation_ctx: Context | None = None,
@@ -83,6 +84,7 @@ def sample(
8384
Args:
8485
action : The action object to be sampled.
8586
context: The context to be passed to the sampling strategy.
87+
requirements: The requirements to be used by the sampling strategy (merged with global requirements).
8688
generate_logs: Optional list of GenerateLog objects. If None, no collection happens.
8789
validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx.
8890
"""
@@ -152,10 +154,10 @@ def sample(
152154
self,
153155
action: Component,
154156
context: Context,
157+
requirements: list[Requirement],
155158
*,
156159
show_progress: bool = True,
157160
generate_logs: list[GenerateLog] | None = None,
158-
requirements: list[Requirement] | None = None,
159161
validation_ctx: Context | None = None,
160162
) -> SamplingResult:
161163
"""This method performs a sampling operation based on the given instruction.
@@ -165,7 +167,7 @@ def sample(
165167
context: The context to be passed to the sampling strategy.
166168
show_progress: if true, a tqdm progress bar is used. Otherwise, messages will still be sent to flog.
167169
generate_logs: If provided, the generations will be logged.
168-
requirements: List of requirements to test against.
170+
requirements: List of requirements to test against (only used when the strategy has no global requirements of its own).
169171
validation_ctx: Optional context to use for validation. If None, validation_ctx = ctx.
170172
171173
Returns:
@@ -192,12 +194,14 @@ def sample(
192194
sampled_scores: list[list[tuple[Requirement, ValidationResult]]] = []
193195
sampled_actions: list[Component] = []
194196

197+
reqs = []
198+
# global requirements supersede local requirements (global requirements can be defined by user)
199+
# Todo: re-evaluate if this makes sense
195200
if self.requirements is not None:
196-
reqs = self.requirements
197-
if requirements is not None:
198-
flog.warn("Some requirements are ignored.")
199-
else:
200-
reqs = requirements if requirements is not None else []
201+
reqs += self.requirements
202+
elif requirements is not None:
203+
reqs += requirements
204+
reqs = list(set(reqs))
201205

202206
loop_count = 0
203207
loop_budget_range_iterator = (

mellea/stdlib/session.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -218,7 +218,9 @@ def instruct(
218218
)
219219

220220
# sample
221-
res = strategy.sample(i, self.ctx, generate_logs=generate_logs)
221+
res = strategy.sample(
222+
i, self.ctx, i.requirements, generate_logs=generate_logs
223+
)
222224

223225
# make sure that one Log is marked as the one related to res.result
224226
if res.success:
Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
from mellea import LinearContext, start_session
2+
from mellea.backends import ModelOption
3+
from mellea.stdlib.sampling import (
4+
AgenticSamplingStrategy,
5+
RejectionSamplingStrategy,
6+
SamplingResult,
7+
)
8+
9+
10+
class TestSamplingCtxCase:
    """Context-related checks for ``instruct`` calls that use a sampling strategy.

    A single session ``m`` is created in the class body and shared by both
    tests; every test resets ``m.ctx`` before issuing its instruction.
    """

    m = start_session(
        model_options={ModelOption.MAX_NEW_TOKENS: 100}, ctx=LinearContext()
    )

    def _instruct_with(self, strategy_cls):
        """Reset the shared context and run the common instruction.

        Args:
            strategy_cls: Sampling strategy class; it is instantiated with
                ``loop_budget=3``.

        Returns:
            Whatever ``instruct`` returns when ``return_sampling_results=True``.
        """
        self.m.ctx.reset()
        return self.m.instruct(
            "Write a sentence.",
            requirements=[
                "be funny",
                "be formal",
                "use only words starting with the letter w",
            ],
            strategy=strategy_cls(loop_budget=3),
            return_sampling_results=True,
        )

    def _run_asserts_for_ctx_testing(self, res):
        """Assert the expectations that both strategies share.

        Args:
            res: The object returned by ``instruct``.
        """
        assert isinstance(res, SamplingResult), "res should be a SamplingResult."
        assert isinstance(res.value, str), "Value should be set and a string."

        n_generations = len(res.sample_generations)
        assert n_generations >= 1, (
            "sample generation should have at least one sample."
        )

        n_validations = len(res.sample_validations)
        assert n_validations >= 1, (
            "sample validation should have at least one sample."
        )

        # One validation result is expected per requirement passed to instruct.
        first_validation = res.sample_validations[0]
        assert len(first_validation) == 3, (
            "there should be 3 validation results."
        )

        ctx_entries = self.m.ctx._ctx
        assert len(ctx_entries) == 2, (
            "there should only be a message and a response in the ctx."
        )

    def test_ctx_for_rejection_sampling(self):
        """Rejection sampling: the last prompt holds exactly one element."""
        res = self._instruct_with(RejectionSamplingStrategy)
        self._run_asserts_for_ctx_testing(res)

        prompt_size = len(self.m.last_prompt())
        assert prompt_size == 1, "Last prompt should only have only one instruction inside - independent of sampling iterations."

    def test_ctx_for_agentic(self):
        """Agentic sampling: the last prompt holds 2n-1 elements for n generations."""
        res = self._instruct_with(AgenticSamplingStrategy)
        self._run_asserts_for_ctx_testing(res)

        prompt_size = len(self.m.last_prompt())
        expected_size = len(res.sample_generations) * 2 - 1
        assert prompt_size == expected_size, "For n sampling iterations there should be 2n-1 prompt conversation elements in the last prompt."

0 commit comments

Comments
 (0)