Skip to content

Commit 73e799b

Browse files
Merge pull request #168 from generative-computing/hen/modular_sampling
fix: convert sampling to a package
2 parents 4ae6d7c + 616b821 commit 73e799b

File tree

7 files changed

+720
-667
lines changed

7 files changed

+720
-667
lines changed

docs/examples/best_of_n/prm.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
)
99
from mellea.backends.types import ModelOption
1010
from mellea.stdlib.rewards.prm_scorer import PRMScorer
11-
from mellea.stdlib.sampling import BestofNSamplingStrategy
11+
from mellea.stdlib.sampling.best_of_n import BestofNSamplingStrategy
1212

1313
# create a session for the generator using Granite 3.3 8B on Huggingface and a simple context [see below]
1414
m = start_session(backend_name="hf", model_options={ModelOption.MAX_NEW_TOKENS: 512})

mellea/stdlib/funcs.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def act(
4040
context: Context,
4141
backend: Backend,
4242
*,
43-
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
43+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2),
4444
return_sampling_results: Literal[False] = False,
4545
format: type[BaseModelSubclass] | None = None,
4646
model_options: dict | None = None,
@@ -54,7 +54,7 @@ def act(
5454
context: Context,
5555
backend: Backend,
5656
*,
57-
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
57+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2),
5858
return_sampling_results: Literal[True],
5959
format: type[BaseModelSubclass] | None = None,
6060
model_options: dict | None = None,
@@ -68,7 +68,7 @@ def act(
6868
backend: Backend,
6969
*,
7070
requirements: list[Requirement] | None = None,
71-
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
71+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2),
7272
return_sampling_results: bool = False,
7373
format: type[BaseModelSubclass] | None = None,
7474
model_options: dict | None = None,
@@ -114,7 +114,7 @@ async def _act(
114114
backend: Backend,
115115
*,
116116
requirements: list[Requirement] | None = None,
117-
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
117+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2),
118118
return_sampling_results: bool = False,
119119
format: type[BaseModelSubclass] | None = None,
120120
model_options: dict | None = None,
@@ -213,7 +213,7 @@ def instruct(
213213
user_variables: dict[str, str] | None = None,
214214
prefix: str | CBlock | None = None,
215215
output_prefix: str | CBlock | None = None,
216-
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
216+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2),
217217
return_sampling_results: Literal[False] = False,
218218
format: type[BaseModelSubclass] | None = None,
219219
model_options: dict | None = None,
@@ -234,7 +234,7 @@ def instruct(
234234
user_variables: dict[str, str] | None = None,
235235
prefix: str | CBlock | None = None,
236236
output_prefix: str | CBlock | None = None,
237-
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
237+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2),
238238
return_sampling_results: Literal[True],
239239
format: type[BaseModelSubclass] | None = None,
240240
model_options: dict | None = None,
@@ -254,7 +254,7 @@ def instruct(
254254
user_variables: dict[str, str] | None = None,
255255
prefix: str | CBlock | None = None,
256256
output_prefix: str | CBlock | None = None,
257-
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=1),
257+
strategy: SamplingStrategy | None = RejectionSamplingStrategy(loop_budget=2),
258258
return_sampling_results: bool = False,
259259
format: type[BaseModelSubclass] | None = None,
260260
model_options: dict | None = None,

0 commit comments

Comments
 (0)