Skip to content

Commit 5620d8e

Browse files
refactor: replace n parameter with sampling_params in generate()
Allows per-request override of any sampling parameter (temperature, top_p, n, etc.) instead of just n. Preserves the output_kind=FINAL_ONLY enforcement from the post_init logic by re-applying it in generate() to whichever sampling params are used for the request.
1 parent 66a7ee9 commit 5620d8e

File tree

2 files changed

+11
-28
lines changed

2 files changed

+11
-28
lines changed

src/forge/actors/generator.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -288,15 +288,19 @@ def split_keys(keys):
288288

289289
@endpoint
290290
async def generate(
291-
self, prompt: str, *, priority: int = 0, n: int | None = None
291+
self,
292+
prompt: str,
293+
*,
294+
priority: int = 0,
295+
sampling_params: SamplingParams | None = None,
292296
) -> list[Completion]:
293297
"""Generate a response for the given prompt
294298
295299
Args:
296300
prompt (str): The prompt to generate a response for.
297301
priority (int, optional): The priority of the request. Defaults to 0.
298-
n (int, optional): Number of completions to generate. If not provided, uses the default
299-
from self.sampling_params.n.
302+
sampling_params (SamplingParams, optional): Sampling parameters to use for this request.
303+
If not provided, uses self.sampling_params.
300304
301305
Returns:
302306
list[Completion]: n completions from vLLM based on your prompt.
@@ -305,10 +309,10 @@ async def generate(
305309
t.start()
306310
record_metric("generator/generate/count_requests", 1, Reduce.SUM)
307311

308-
if n is not None and n != self.sampling_params.n:
309-
params = self.sampling_params.__replace__(n=n)
310-
else:
311-
params = self.sampling_params
312+
params = sampling_params or self.sampling_params
313+
# Ensure output_kind is set to FINAL_ONLY (as required by post_init)
314+
if params.output_kind != RequestOutputKind.FINAL_ONLY:
315+
params = params.__replace__(output_kind=RequestOutputKind.FINAL_ONLY)
312316

313317
self.request_id += 1 % sys.maxsize
314318
request_id = str(self.request_id)

tests/unit_tests/test_generator_config.py

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -132,27 +132,6 @@ def test_generator_yaml_config_loading(self):
132132
self.assertEqual(generator.sampling_params.n, 2)
133133
self.assertEqual(generator.sampling_params.max_tokens, 32)
134134

135-
@pytest.mark.skipif(
136-
_import_error(),
137-
reason="Import error, likely due to missing dependencies on CI.",
138-
)
139-
def test_generate_n_parameter_logic(self):
140-
from forge.actors.generator import Generator
141-
142-
generator = Generator(sampling_params={"n": 2, "max_tokens": 16})
143-
base_params = generator.sampling_params
144-
145-
def get_params_for(n_override: int | None):
146-
if n_override in (None, base_params.n):
147-
return base_params
148-
return base_params.__replace__(n=n_override)
149-
150-
self.assertIs(get_params_for(None), base_params)
151-
self.assertIs(get_params_for(2), base_params)
152-
updated = get_params_for(4)
153-
self.assertEqual(updated.n, 4)
154-
self.assertIsNot(updated, base_params)
155-
156135

157136
if __name__ == "__main__":
158137
unittest.main()

0 commit comments

Comments
 (0)