
Commit e42059b

Author: Felipe Mello (committed)
Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4
2 parents: e3c7a99 + e901ad5

29 files changed (+1979, -504 lines)

.github/workflows/gpu_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-name: GPU tests
+name: GPU Tests
 
 on:
   schedule:

.github/workflows/unit_test.yaml

Lines changed: 1 addition & 1 deletion
@@ -1,4 +1,4 @@
-name: Unit Test
+name: Unit Tests
 
 on:
   pull_request:

README.md

Lines changed: 2 additions & 1 deletion
@@ -1,7 +1,8 @@
 # <img width="35" height="35" alt="image" src="https://github.com/user-attachments/assets/2700a971-e5d6-4036-b03f-2f89c9791609" /> Forge
 
-
 #### A PyTorch-native agentic RL library that lets you focus on algorithms—not infra.
+[![Unit Tests](https://github.com/meta-pytorch/forge/actions/workflows/unit_test.yaml/badge.svg)](https://github.com/meta-pytorch/forge/actions/workflows/unit_test.yaml)
+[![GPU Tests](https://github.com/meta-pytorch/forge/actions/workflows/gpu_test.yaml/badge.svg)](https://github.com/meta-pytorch/forge/actions/workflows/gpu_test.yaml)
 
 ## Overview
 The primary purpose of the Forge ecosystem is to delineate infra concerns from model concerns thereby making RL experimentation easier. Forge delivers this by providing clear RL abstractions and one scalable implementation of these abstractions. When you need fine-grained control over placement, fault handling/redirecting training loads during a run, or communication patterns, the primitives are there. When you don’t, you can focus purely on your RL algorithm.

apps/grpo/main.py

Lines changed: 47 additions & 94 deletions
@@ -20,14 +20,15 @@
     get_dcp_whole_state_dict_key,
     get_param_prefix,
 )
-from forge.actors.policy import Policy
+from forge.actors.generator import Generator
 from forge.actors.reference_model import ReferenceModel
 from forge.actors.replay_buffer import ReplayBuffer
 from forge.actors.trainer import RLTrainer
 from forge.cli.config import parse
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
+from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
@@ -41,73 +42,54 @@
 
 @dataclass
 class Episode:
-    # TODO: add adtional layer for multi-turn
     episode_id: str
-    request: str
-    policy_version: int
     pad_id: int
     request_len: int
     response_len: int
     target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
+    # Processed data
+    completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None
 
     @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+    def policy_version(self) -> int | None:
+        return self.completion.generator_version
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        request_tokens: torch.Tensor = self.completion.prompt_ids
+        tensor = torch.tensor(request_tokens, dtype=torch.long)
         if tensor.shape[0] < self.request_len:  # left pad
             diff = self.request_len - tensor.shape[0]
             tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
         return tensor
 
     @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+    def response_tensor(self) -> torch.Tensor:
+        response_tokens: torch.Tensor = self.completion.token_ids
+        tensor = torch.tensor(response_tokens, dtype=torch.long)
         if tensor.shape[0] < self.response_len:  # right pad
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor
 
 
-@dataclass
-class Group:
-    group_id: str
-    episodes: list[Episode]
-
-    @classmethod
-    def new_group(
-        cls,
-        group_id: int,
-        group_size: int,
-        request: str,
-        policy_version: int,
-        pad_id: int,
-        request_len: int,
-        response_len: int,
-        target: Any = None,
-    ):
-        episodes = []
-        for _ in range(group_size):
-            episodes.append(
-                Episode(
-                    episode_id=str(uuid.uuid4()),
-                    request=request,
-                    policy_version=policy_version,
-                    pad_id=pad_id,
-                    request_len=request_len,
-                    response_len=response_len,
-                    target=target,
-                )
-            )
-        return cls(str(group_id), episodes)
+# Represents the group (G) of episodes in GRPO
+Group = list[Episode]
+
+# Represents the Policy Model to collect data from
+Policy = Generator
 
 
-def collate(batches: list[list[Episode]]):
+def collate(
+    batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    """
+    Collates a list of batches into a single batch of inputs and targets.
+    Each batch is a list of episodes, and each episode is a dict of tensors.
+    """
     inputs = []
     targets = []
     for batch in batches:
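For context on the padding the new request_tensor / response_tensor properties preserve: below is a minimal standalone sketch of the same left-pad / right-pad logic on toy token ids. The values and lengths are illustrative only and stand in for completion.prompt_ids / completion.token_ids; this is not tied to the Forge Completion type.

import torch
import torch.nn.functional as F

pad_id = 0
request_len, response_len = 6, 6

# Toy ids standing in for completion.prompt_ids / completion.token_ids (illustrative)
prompt_ids = torch.tensor([11, 12, 13], dtype=torch.long)
token_ids = torch.tensor([21, 22, 23, 24], dtype=torch.long)

# Left-pad the request to request_len, as Episode.request_tensor does
request = F.pad(prompt_ids, (request_len - prompt_ids.shape[0], 0), value=pad_id)

# Right-pad the response to response_len, as Episode.response_tensor does
response = F.pad(token_ids, (0, response_len - token_ids.shape[0]), value=pad_id)

print(request.tolist())   # [0, 0, 0, 11, 12, 13]
print(response.tolist())  # [21, 22, 23, 24, 0, 0]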
@@ -221,7 +203,7 @@ class ComputeAdvantages(ForgeActor):
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group.episodes]])
+        rewards = torch.tensor([[e.reward for e in group]])
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
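For reference, the computation above normalizes each episode's reward against its group's mean and standard deviation (group-relative advantages). A minimal standalone sketch with illustrative reward values, mirroring the lines in the hunk:

import torch

# Toy rewards for one group of G = 4 episodes (illustrative values)
rewards = torch.tensor([[1.0, 0.0, 0.5, 1.0]])

mean = rewards.mean(1, keepdim=True)
std = rewards.std(1, keepdim=True)
advantages = (rewards - mean) / (std + 1e-4)  # same epsilon as in ComputeAdvantages.compute

# Rewards above the group mean map to positive advantages, below to negative
print(advantages.squeeze(0).tolist())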
@@ -383,44 +365,32 @@ async def continuous_rollouts():
         t.step("data_loading")
 
         prompt, target = sample["request"], sample["target"]
-        responses = await policy.generate.route(prompt)
-        # TODO: this shall be part of the responses metadata instead of a separate call
-        version = await policy.get_version.route()
-
+        responses: list[Completion] = await policy.generate.route(prompt)
         t.step("policy_generation")
 
-        assert (
-            len(responses) > 0
-        ), "Sanity check: Responses should NEVER return empty"
-        assert (
-            version := responses[0].generator_version
-        ) is not None, "Response must indicate a version"
-        group = Group.new_group(
-            group_id=rollout_count,
-            group_size=group_size,
-            request=prompt,
-            policy_version=version,
-            pad_id=pad_id,
-            request_len=max_req_tokens,
-            response_len=max_res_tokens,
-            target=target,
-        )
-
+        # Construct episodes and calculate rewards
+        episodes = []
         input_ids = torch.ones(
             (group_size, max_req_tokens + max_res_tokens),
             dtype=torch.long,
-            device="cuda",
         )
-        # Populate episode info and calculate rewards
-        for i, (episode, response) in enumerate(zip(group.episodes, responses)):
-            episode.request_tokens = response.prompt_ids
-            episode.response_tokens = response.token_ids
-            episode.response = response.text
-            input_ids[i, :max_req_tokens] = episode.request_tensor
-            input_ids[i, max_req_tokens:] = episode.response_tensor
+        for i, response in enumerate(responses):
+            episode = Episode(
+                episode_id=str(uuid.uuid4()),
+                pad_id=pad_id,
+                request_len=max_req_tokens,
+                response_len=max_res_tokens,
+                target=target,
+                completion=response,
+            )
             episode.reward = await reward_actor.evaluate_response.route(
                 prompt=prompt, response=response.text, target=target
             )
+            episodes.append(episode)
+
+            # Build input_ids for reference logprobs
+            input_ids[i, :max_req_tokens] = episode.request_tensor
+            input_ids[i, max_req_tokens:] = episode.response_tensor
 
         t.step("reward_evaluation")
 
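The rollout loop above packs each episode's left-padded request and right-padded response side by side into one row of input_ids before the reference-model logprob call. A minimal standalone sketch of that row layout with toy tensors (sizes and values are illustrative only):

import torch

group_size, max_req_tokens, max_res_tokens = 2, 4, 4

# Toy rows standing in for episode.request_tensor (left-padded) and episode.response_tensor (right-padded)
requests = [torch.tensor([0, 0, 11, 12]), torch.tensor([0, 13, 14, 15])]
responses = [torch.tensor([21, 22, 0, 0]), torch.tensor([23, 24, 25, 0])]

input_ids = torch.ones((group_size, max_req_tokens + max_res_tokens), dtype=torch.long)
for i in range(group_size):
    input_ids[i, :max_req_tokens] = requests[i]
    input_ids[i, max_req_tokens:] = responses[i]

print(input_ids)  # each row: [request tokens | response tokens]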

@@ -429,14 +399,13 @@ async def continuous_rollouts():
         )
         t.step("reference_model_calculate_logprobs")
 
-        for i, episode in enumerate(group.episodes):
+        for i, episode in enumerate(episodes):
             episode.ref_logprobs = ref_logprobs[i]
         del ref_logprobs, input_ids
-        t.step("compute_logprobs")
 
         # Calculate advantages and add to replay buffer
-        advantages = await compute_advantages.compute.call_one(group)
-        for episode, advantage in zip(group.episodes, advantages):
+        advantages = await compute_advantages.compute.call_one(episodes)
+        for episode, advantage in zip(episodes, advantages):
             episode.advantage = advantage
             await replay_buffer.add.call_one(episode)
 
@@ -524,22 +493,6 @@ async def continuous_training():
 
     training_task.cancel()
 
-    # give mlogger time to shutdown backends, otherwise they can stay running.
-    # TODO (felipemello) find more elegant solution
-    await mlogger.shutdown.call_one()
-    await asyncio.sleep(2)
-
-    await asyncio.gather(
-        DatasetActor.shutdown(dataloader),
-        policy.shutdown(),
-        RLTrainer.shutdown(trainer),
-        ReplayBuffer.shutdown(replay_buffer),
-        ComputeAdvantages.shutdown(compute_advantages),
-        ref_model.shutdown(),
-        reward_actor.shutdown(),
-    )
-    # TODO - add a global shutdown that implicitly shuts down all services
-    # and remote allocations
     await shutdown()
 
 
apps/grpo/qwen3_8b.yaml

Lines changed: 0 additions & 1 deletion
@@ -42,7 +42,6 @@ policy:
 
 # Trainer configuration
 trainer:
-  use_dcp: true
   model:
     name: qwen3
     flavor: 8B
