
Commit 791cb26

Ritesh1905 and rithesh authored

Policy uses the generic completion data model (#207)

* using the completions data model
* working changes
* few mode fixes
* lints
* review comments

Co-authored-by: rithesh <[email protected]>
1 parent 4d292cf commit 791cb26

5 files changed: 60 additions, 35 deletions

apps/grpo/main.py

Lines changed: 2 additions & 4 deletions
@@ -305,10 +305,8 @@ async def continuous_rollouts():
             device="cuda",
         )
         # Populate episode info and calculate rewards
-        for i, (episode, response) in enumerate(
-            zip(group.episodes, responses.outputs)
-        ):
-            episode.request_tokens = responses.prompt_token_ids
+        for i, (episode, response) in enumerate(zip(group.episodes, responses)):
+            episode.request_tokens = response.prompt_ids
             episode.response_tokens = response.token_ids
             episode.response = response.text
             input_ids[i, :max_req_tokens] = episode.request_tensor
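Taken together with the policy.py changes further down, the loop body now reads per-sample data off each Completion instead of mixing the shared RequestOutput fields (responses.prompt_token_ids) with per-output ones. A minimal sketch of the new consumption pattern, with a stubbed Completion so it runs standalone; the Episode container here is illustrative, not the app's actual class:

    from dataclasses import dataclass

    import torch


    @dataclass
    class Completion:  # trimmed stand-in for forge.data_models.completion.Completion
        text: str
        prompt_ids: torch.Tensor
        token_ids: torch.Tensor


    @dataclass
    class Episode:  # illustrative container, not the app's actual Episode class
        request_tokens: torch.Tensor | None = None
        response_tokens: torch.Tensor | None = None
        response: str | None = None


    episodes = [Episode(), Episode()]
    responses = [  # what policy.generate now returns: one Completion per sample
        Completion("4", torch.tensor([1, 2]), torch.tensor([9])),
        Completion("7", torch.tensor([1, 2]), torch.tensor([11])),
    ]

    # New pattern: zip directly over the list; prompt ids live on each Completion.
    for episode, response in zip(episodes, responses):
        episode.request_tokens = response.prompt_ids
        episode.response_tokens = response.token_ids
        episode.response = response.text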

apps/toy_rl/sumdigits.py

Lines changed: 5 additions & 24 deletions
@@ -358,7 +358,7 @@ class RewardActor(ForgeActor):
 
     @endpoint
     async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
-        reward = 1.0 if response.strip() == "10" else 0.0
+        reward = 1.0 if response.strip() == target else 0.0
         return reward
 
 
@@ -396,18 +396,7 @@ def generate_sample(self, step: int) -> dict[str, str]:
 
     def generate_one(self, step: int) -> str:
         """Generate number based on training step for curriculum learning."""
-        if step < 200:
-            # Early training: 2-digit numbers (10-99)
-            min_val, max_val = 10, 99
-        elif step < 1000:
-            # Later training: 1-4 digit numbers (0-1000)
-            min_val, max_val = 0, 1000
-        elif step < 3000:
-            # Later training: 1-6 digit numbers (0-100000)
-            min_val, max_val = 0, 100000
-        else:
-            # Later training: 1-8 digit numbers (0-10000000)
-            min_val, max_val = 0, 10000000
+        min_val, max_val = 10, 100
 
         number = random.randint(min_val, max_val)
         return str(number)
@@ -497,19 +486,11 @@ async def continuous_rollouts():
             )
 
             # TODO: Parallelize the following calculation
-            for episode, response in zip(group.episodes, responses.outputs):
-                episode.request_tokens = responses.prompt_token_ids
+            for episode, response in zip(group.episodes, responses):
+                episode.request_tokens = response.prompt_ids
                 episode.response_tokens = response.token_ids
                 episode.response = response.text
-                episode.response_logprobs = torch.tensor(
-                    [
-                        top_k_dict[token].logprob
-                        for token, top_k_dict in zip(
-                            response.token_ids,
-                            response.logprobs,
-                        )
-                    ]
-                )
+                episode.response_logprobs = response.logprobs
                 episode.ref_logprobs = await ref_model.forward.choose(episode)
                 episode.reward = await reward_actor.evaluate_response.choose(
                     prompt=prompt, response=response.text, target=target
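The first two hunks here are behavior fixes rather than part of the data-model migration: the reward now compares the response against the episode's actual target instead of the hardcoded string "10", and the curriculum schedule collapses to a single fixed range (random.randint is inclusive, so 10-100). A standalone check of the corrected reward logic:

    def evaluate_response(response: str, target: str) -> float:
        # Reward 1.0 only when the stripped response matches this episode's target.
        # The old code compared against the literal "10", so other targets could
        # never be rewarded.
        return 1.0 if response.strip() == target else 0.0


    assert evaluate_response(" 12 ", "12") == 1.0
    assert evaluate_response("10", "12") == 0.0  # previously scored 1.0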

apps/vllm/main.py

Lines changed: 5 additions & 3 deletions
@@ -17,8 +17,8 @@
 from forge.cli.config import parse
 from forge.controller.provisioner import shutdown
 
+from forge.data_models.completion import Completion
 from omegaconf import DictConfig
-from vllm.outputs import RequestOutput
 
 os.environ["HYPERACTOR_MESSAGE_DELIVERY_TIMEOUT_SECS"] = "600"
 os.environ["HYPERACTOR_CODE_MAX_FRAME_LENGTH"] = "1073741824"
@@ -36,11 +36,13 @@ async def run(cfg: DictConfig):
     try:
         async with policy.session():
             print("Requesting generation...")
-            response_output: RequestOutput = await policy.generate.choose(prompt=prompt)
+            response_output: list[Completion] = await policy.generate.choose(
+                prompt=prompt
+            )
 
             print("\nGeneration Results:")
             print("=" * 80)
-            for batch, response in enumerate(response_output.outputs):
+            for batch, response in enumerate(response_output):
                 print(f"Sample {batch + 1}:")
                 print(f"User: {prompt}")
                 print(f"Assistant: {response.text}")

src/forge/actors/policy.py

Lines changed: 43 additions & 3 deletions
@@ -25,7 +25,7 @@
 from vllm.entrypoints.utils import _validate_truncation_size
 from vllm.executor.multiproc_worker_utils import set_multiprocessing_worker_envs
 from vllm.lora.request import LoRARequest
-from vllm.outputs import RequestOutput
+from vllm.outputs import CompletionOutput, RequestOutput
 from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind, SamplingParams
 from vllm.transformers_utils.tokenizer_group import init_tokenizer_from_configs
 from vllm.usage.usage_lib import UsageContext
@@ -44,6 +44,9 @@
 from forge.controller import ForgeActor, get_proc_mesh, stop_proc_mesh
 
 from forge.data.sharding import VLLMSharding
+from forge.data_models.completion import Completion
+from forge.data_models.prompt import to_prompt
+
 from forge.interfaces import Policy as PolicyInterface
 from forge.types import ProcessConfig
 
@@ -258,7 +261,7 @@ def start_processing(self):
         self._run_task = asyncio.create_task(self.run())
 
     @endpoint
-    async def generate(self, prompt: str, priority: int = 0) -> RequestOutput:
+    async def generate(self, prompt: str, priority: int = 0) -> list[Completion]:
         """Generate a response for the given prompt
 
         Args:
@@ -362,8 +365,9 @@ async def run(self):
 
             for request_output in processed_outputs.request_outputs:
                 if request_output.finished:
+                    completions = self._to_completions(request_output)
                     _, fut = self.requests.pop(request_output.request_id)
-                    fut.set_result(request_output)
+                    fut.set_result(completions)
 
     @endpoint
     async def update_weights(self, policy_version: int):
@@ -396,6 +400,42 @@ async def get_version(self) -> int:
     async def stop(self):
         self.running = False
 
+    def _to_completions(self, request_output: RequestOutput) -> list[Completion]:
+        """Convert a RequestOutput to a list of Completion objects."""
+        completions = []
+        original_prompt = request_output.prompt
+        prompt_token_ids = request_output.prompt_token_ids
+        for output in request_output.outputs:
+            completions.append(
+                Completion(
+                    # TODO: the to_prompt encoding will be different from the original.
+                    # This is okay for now, since I don't see any direct usage of prompt using completion object.
+                    prompt=to_prompt(original_prompt),
+                    stop_reason=output.finish_reason,
+                    text=output.text,
+                    prompt_ids=torch.tensor(prompt_token_ids),
+                    token_ids=torch.tensor(output.token_ids),
+                    logprobs=self._extract_logprobs(output),
+                )
+            )
+
+        return completions
+
+    def _extract_logprobs(self, one_sample: CompletionOutput) -> torch.Tensor | None:
+        """
+        Extract log probabilities from a sample, if available.
+        """
+        if one_sample.logprobs is not None:
+            return torch.tensor(
+                [
+                    top_k_dict[token].logprob
+                    for token, top_k_dict in zip(
+                        one_sample.token_ids, one_sample.logprobs
+                    )
+                ]
+            )
+        return None
+
 
 @dataclass
 class PolicyWorker(ForgeActor):
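The conversion hinges on vLLM's logprobs layout: CompletionOutput.logprobs is a list with one dict[token_id, Logprob] per generated position, and top_k_dict[token].logprob picks out the logprob of the token that was actually sampled at that position. A standalone sketch of the same extraction, with a stubbed Logprob so it runs without vLLM:

    from dataclasses import dataclass

    import torch


    @dataclass
    class Logprob:  # stand-in for vLLM's Logprob; only .logprob is read here
        logprob: float


    # One dict per generated position, keyed by candidate token id.
    token_ids = [42, 7]
    logprobs = [
        {42: Logprob(-0.11), 99: Logprob(-2.30)},  # position 0: token 42 was sampled
        {7: Logprob(-0.52)},                       # position 1: token 7 was sampled
    ]

    # Same extraction as Policy._extract_logprobs above.
    extracted = torch.tensor(
        [top_k_dict[token].logprob for token, top_k_dict in zip(token_ids, logprobs)]
    )
    print(extracted)  # tensor([-0.1100, -0.5200])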

src/forge/data_models/completion.py

Lines changed: 5 additions & 1 deletion
@@ -8,6 +8,7 @@
 from typing import Optional
 
 import torch
+
 from forge.data_models.prompt import Prompt
 
 
@@ -28,4 +29,7 @@ class Completion:
     token_ids: torch.Tensor
 
    # the log probabilities of the target tokens
-    log_probs: Optional[torch.Tensor] = None
+    logprobs: Optional[torch.Tensor] = None
+
+    # the reason for stopping the generation
+    stop_reason: str | None = None
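For reference, the fields that _to_completions fills in line up with the data model after this change. A reconstruction of the full dataclass from the hunks above; the fields not shown in any hunk (prompt, text, prompt_ids) and their ordering and comments are inferred from how _to_completions constructs the object:

    from dataclasses import dataclass
    from typing import Optional

    import torch

    from forge.data_models.prompt import Prompt


    @dataclass
    class Completion:
        # the prompt the completion was generated from
        prompt: Prompt

        # the decoded completion text
        text: str

        # the prompt token ids
        prompt_ids: torch.Tensor

        # the generated token ids
        token_ids: torch.Tensor

        # the log probabilities of the target tokens
        logprobs: Optional[torch.Tensor] = None

        # the reason for stopping the generation
        stop_reason: str | None = None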
