2626from mellea .stdlib .mify import mify
2727from mellea .stdlib .mobject import MObjectProtocol
2828from mellea .stdlib .requirement import Requirement , ValidationResult
29- from mellea .stdlib .sampling import SamplingResult , SamplingStrategy
29+ from mellea .stdlib .sampling import (
30+ RejectionSamplingStrategy ,
31+ SamplingResult ,
32+ SamplingStrategy ,
33+ )
3034
3135
3236# TODO: JAL
@@ -40,7 +44,7 @@ def act(
4044 context : Context ,
4145 backend : Backend ,
4246 * ,
43- strategy : SamplingStrategy | None = None ,
47+ strategy : SamplingStrategy | None = RejectionSamplingStrategy ( loop_budget = 1 ) ,
4448 return_sampling_results : Literal [False ] = False ,
4549 format : type [BaseModelSubclass ] | None = None ,
4650 model_options : dict | None = None ,
@@ -54,7 +58,7 @@ def act(
5458 context : Context ,
5559 backend : Backend ,
5660 * ,
57- strategy : SamplingStrategy | None = None ,
61+ strategy : SamplingStrategy | None = RejectionSamplingStrategy ( loop_budget = 1 ) ,
5862 return_sampling_results : Literal [True ],
5963 format : type [BaseModelSubclass ] | None = None ,
6064 model_options : dict | None = None ,
@@ -68,7 +72,7 @@ def act(
6872 backend : Backend ,
6973 * ,
7074 requirements : list [Requirement ] | None = None ,
71- strategy : SamplingStrategy | None = None ,
75+ strategy : SamplingStrategy | None = RejectionSamplingStrategy ( loop_budget = 1 ) ,
7276 return_sampling_results : bool = False ,
7377 format : type [BaseModelSubclass ] | None = None ,
7478 model_options : dict | None = None ,
@@ -114,7 +118,7 @@ async def _act(
114118 backend : Backend ,
115119 * ,
116120 requirements : list [Requirement ] | None = None ,
117- strategy : SamplingStrategy | None = None ,
121+ strategy : SamplingStrategy | None = RejectionSamplingStrategy ( loop_budget = 1 ) ,
118122 return_sampling_results : bool = False ,
119123 format : type [BaseModelSubclass ] | None = None ,
120124 model_options : dict | None = None ,
@@ -144,12 +148,8 @@ async def _act(
144148 "Must provide a SamplingStrategy when return_sampling_results==True"
145149 )
146150
147- if strategy is None :
148- # TODO: JAL.
149- # Change this to just call a rejection sampling strategy with loop = 1 so that we don't dupe the code
150- # That will only work if sampling strategies let you run with no requirements...
151- # maybe we create a sampling strategy just for this use case...
152- # can probably just pass in an empty list and it will work...
151+ # if there is no reason to sample, just generate from the context.
152+ if strategy is None or requirements is None or len (requirements ) == 0 :
153153 result , new_ctx = backend .generate_from_context (
154154 action ,
155155 ctx = context ,
@@ -165,8 +165,7 @@ async def _act(
165165 generate_logs .append (result ._generate_log )
166166
167167 else :
168- if requirements is None :
169- requirements = []
168+ # if there is a reason to sample, use the sampling strategy.
170169
171170 sampling_result = await strategy .sample (
172171 action ,
@@ -185,7 +184,7 @@ async def _act(
185184 generate_logs .append (result ._generate_log )
186185
187186 # TODO: JAL. Extract the context from the sampling result.
188- new_ctx = ChatContext ()
187+ new_ctx = sampling_result . result_ctx
189188 result = sampling_result .result
190189 assert sampling_result .result ._generate_log is not None
191190 assert sampling_result .result ._generate_log .is_final_result , (
@@ -214,7 +213,7 @@ def instruct(
214213 user_variables : dict [str , str ] | None = None ,
215214 prefix : str | CBlock | None = None ,
216215 output_prefix : str | CBlock | None = None ,
217- strategy : SamplingStrategy | None = None ,
216+ strategy : SamplingStrategy | None = RejectionSamplingStrategy ( loop_budget = 1 ) ,
218217 return_sampling_results : Literal [False ] = False ,
219218 format : type [BaseModelSubclass ] | None = None ,
220219 model_options : dict | None = None ,
@@ -235,7 +234,7 @@ def instruct(
235234 user_variables : dict [str , str ] | None = None ,
236235 prefix : str | CBlock | None = None ,
237236 output_prefix : str | CBlock | None = None ,
238- strategy : SamplingStrategy | None = None ,
237+ strategy : SamplingStrategy | None = RejectionSamplingStrategy ( loop_budget = 1 ) ,
239238 return_sampling_results : Literal [True ],
240239 format : type [BaseModelSubclass ] | None = None ,
241240 model_options : dict | None = None ,
@@ -255,7 +254,7 @@ def instruct(
255254 user_variables : dict [str , str ] | None = None ,
256255 prefix : str | CBlock | None = None ,
257256 output_prefix : str | CBlock | None = None ,
258- strategy : SamplingStrategy | None = None ,
257+ strategy : SamplingStrategy | None = RejectionSamplingStrategy ( loop_budget = 1 ) ,
259258 return_sampling_results : bool = False ,
260259 format : type [BaseModelSubclass ] | None = None ,
261260 model_options : dict | None = None ,
0 commit comments