Skip to content

Commit d0f6633

Browse files
1) handling sampling results and context
2) making rejection-sampling default
1 parent e494305 commit d0f6633

File tree

2 files changed

+17
-18
lines changed

2 files changed

+17
-18
lines changed

mellea/stdlib/mellea_functions.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,11 @@
2626
from mellea.stdlib.mify import mify
2727
from mellea.stdlib.mobject import MObjectProtocol
2828
from mellea.stdlib.requirement import Requirement, ValidationResult
29-
from mellea.stdlib.sampling import SamplingResult, SamplingStrategy
29+
from mellea.stdlib.sampling import (
30+
RejectionSamplingStrategy,
31+
SamplingResult,
32+
SamplingStrategy,
33+
)
3034

3135

3236
# TODO: JAL
@@ -40,7 +44,7 @@ def act(
4044
context: Context,
4145
backend: Backend,
4246
*,
43-
strategy: SamplingStrategy | None = None,
47+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
4448
return_sampling_results: Literal[False] = False,
4549
format: type[BaseModelSubclass] | None = None,
4650
model_options: dict | None = None,
@@ -54,7 +58,7 @@ def act(
5458
context: Context,
5559
backend: Backend,
5660
*,
57-
strategy: SamplingStrategy | None = None,
61+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
5862
return_sampling_results: Literal[True],
5963
format: type[BaseModelSubclass] | None = None,
6064
model_options: dict | None = None,
@@ -68,7 +72,7 @@ def act(
6872
backend: Backend,
6973
*,
7074
requirements: list[Requirement] | None = None,
71-
strategy: SamplingStrategy | None = None,
75+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
7276
return_sampling_results: bool = False,
7377
format: type[BaseModelSubclass] | None = None,
7478
model_options: dict | None = None,
@@ -114,7 +118,7 @@ async def _act(
114118
backend: Backend,
115119
*,
116120
requirements: list[Requirement] | None = None,
117-
strategy: SamplingStrategy | None = None,
121+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
118122
return_sampling_results: bool = False,
119123
format: type[BaseModelSubclass] | None = None,
120124
model_options: dict | None = None,
@@ -144,12 +148,8 @@ async def _act(
144148
"Must provide a SamplingStrategy when return_sampling_results==True"
145149
)
146150

147-
if strategy is None:
148-
# TODO: JAL.
149-
# Change this to just call a rejection sampling strategy with loop = 1 so that we don't dupe the code
150-
# That will only work if sampling strategies let you run with no requirements...
151-
# maybe we create a sampling strategy just for this use case...
152-
# can probably just pass in an empty list and it will work...
151+
# if there is no reason to sample, just generate from the context.
152+
if strategy is None or requirements is None or len(requirements) == 0:
153153
result, new_ctx = backend.generate_from_context(
154154
action,
155155
ctx=context,
@@ -165,8 +165,7 @@ async def _act(
165165
generate_logs.append(result._generate_log)
166166

167167
else:
168-
if requirements is None:
169-
requirements = []
168+
# if there is a reason to sample, use the sampling strategy.
170169

171170
sampling_result = await strategy.sample(
172171
action,
@@ -185,7 +184,7 @@ async def _act(
185184
generate_logs.append(result._generate_log)
186185

187186
# TODO: JAL. Extract the context from the sampling result.
188-
new_ctx = ChatContext()
187+
new_ctx = sampling_result.result_ctx
189188
result = sampling_result.result
190189
assert sampling_result.result._generate_log is not None
191190
assert sampling_result.result._generate_log.is_final_result, (
@@ -214,7 +213,7 @@ def instruct(
214213
user_variables: dict[str, str] | None = None,
215214
prefix: str | CBlock | None = None,
216215
output_prefix: str | CBlock | None = None,
217-
strategy: SamplingStrategy | None = None,
216+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
218217
return_sampling_results: Literal[False] = False,
219218
format: type[BaseModelSubclass] | None = None,
220219
model_options: dict | None = None,
@@ -235,7 +234,7 @@ def instruct(
235234
user_variables: dict[str, str] | None = None,
236235
prefix: str | CBlock | None = None,
237236
output_prefix: str | CBlock | None = None,
238-
strategy: SamplingStrategy | None = None,
237+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
239238
return_sampling_results: Literal[True],
240239
format: type[BaseModelSubclass] | None = None,
241240
model_options: dict | None = None,
@@ -255,7 +254,7 @@ def instruct(
255254
user_variables: dict[str, str] | None = None,
256255
prefix: str | CBlock | None = None,
257256
output_prefix: str | CBlock | None = None,
258-
strategy: SamplingStrategy | None = None,
257+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
259258
return_sampling_results: bool = False,
260259
format: type[BaseModelSubclass] | None = None,
261260
model_options: dict | None = None,

mellea/stdlib/session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,8 +332,8 @@ def instruct(self, description: str, **kwargs) -> ModelOutputThunk | SamplingRes
332332
description, context=self.ctx, backend=self.backend, **kwargs
333333
)
334334

335-
# TODO JAL, Hen. Investigate this. What happens to context when Sampling Result returns?
336335
if isinstance(r, SamplingResult):
336+
self.ctx = r.result_ctx
337337
return r
338338
else:
339339
result, context = r

0 commit comments

Comments
 (0)