
Commit cadaef1

Merge remote-tracking branch 'origin/main' into rename-policy
2 parents 3799593 + 4c14792

File tree: 11 files changed (+1510 lines, -153 lines)

apps/grpo/main.py

Lines changed: 53 additions & 96 deletions
@@ -28,7 +28,7 @@
 from forge.controller.actor import ForgeActor
 from forge.controller.provisioner import init_provisioner, shutdown
 from forge.data.rewards import MathReward, ThinkingReward
-from forge.env import MONARCH_HOSTMESH_V1
+from forge.data_models.completion import Completion
 from forge.observability.metric_actors import get_or_create_metric_logger
 from forge.observability.metrics import record_metric, Reduce
 from forge.observability.perf_tracker import Tracer
@@ -42,73 +42,51 @@

 @dataclass
 class Episode:
-    # TODO: add adtional layer for multi-turn
     episode_id: str
-    request: str
-    policy_version: int
     pad_id: int
     request_len: int
     response_len: int
     target: Any | None = None
-    # processed data
-    response: str | None = None
-    request_tokens: list[int] | None = None
-    response_tokens: list[int] | None = None
+    # Processed data
+    completion: Completion | None = None
     ref_logprobs: torch.Tensor | None = None
     reward: float | None = None
     advantage: float | None = None

     @property
-    def request_tensor(self):
-        tensor = torch.tensor(self.request_tokens, dtype=torch.long)
+    def policy_version(self) -> int | None:
+        return self.completion.generator_version
+
+    @property
+    def request_tensor(self) -> torch.Tensor:
+        request_tokens: torch.Tensor = self.completion.prompt_ids
+        tensor = torch.tensor(request_tokens, dtype=torch.long)
         if tensor.shape[0] < self.request_len:  # left pad
             diff = self.request_len - tensor.shape[0]
             tensor = F.pad(tensor, (diff, 0), value=self.pad_id)
         return tensor

     @property
-    def response_tensor(self):
-        tensor = torch.tensor(self.response_tokens, dtype=torch.long)
+    def response_tensor(self) -> torch.Tensor:
+        response_tokens: torch.Tensor = self.completion.token_ids
+        tensor = torch.tensor(response_tokens, dtype=torch.long)
         if tensor.shape[0] < self.response_len:  # right pad
             diff = self.response_len - tensor.shape[0]
             tensor = F.pad(tensor, (0, diff), value=self.pad_id)
         return tensor


-@dataclass
-class Group:
-    group_id: str
-    episodes: list[Episode]
-
-    @classmethod
-    def new_group(
-        cls,
-        group_id: int,
-        group_size: int,
-        request: str,
-        policy_version: int,
-        pad_id: int,
-        request_len: int,
-        response_len: int,
-        target: Any = None,
-    ):
-        episodes = []
-        for _ in range(group_size):
-            episodes.append(
-                Episode(
-                    episode_id=str(uuid.uuid4()),
-                    request=request,
-                    policy_version=policy_version,
-                    pad_id=pad_id,
-                    request_len=request_len,
-                    response_len=response_len,
-                    target=target,
-                )
-            )
-        return cls(str(group_id), episodes)
+# Represents the group (G) of episodes in GRPO
+Group = list[Episode]


-def collate(batches: list[list[Episode]]):
+def collate(
+    batches: list[Group],
+) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+    """
+    Collates a list of batches into a single batch of inputs and targets.
+    Each batch is a list of episodes, and each episode is a dict of tensors.
+    """
     inputs = []
     targets = []
     for batch in batches:
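
With this hunk, Episode no longer stores raw token lists; its tensors are derived lazily from the attached completion. Below is a minimal, self-contained sketch of that padding behavior. StubCompletion is a stand-in for forge.data_models.completion.Completion (assumed here to expose prompt_ids, token_ids, and generator_version), not the real class.

from dataclasses import dataclass
from typing import Any

import torch
import torch.nn.functional as F


@dataclass
class StubCompletion:
    # Stand-in for forge.data_models.completion.Completion (assumed fields only).
    prompt_ids: list[int]
    token_ids: list[int]
    generator_version: int | None = None


@dataclass
class Episode:
    episode_id: str
    pad_id: int
    request_len: int
    response_len: int
    target: Any | None = None
    completion: StubCompletion | None = None

    @property
    def request_tensor(self) -> torch.Tensor:
        tensor = torch.tensor(self.completion.prompt_ids, dtype=torch.long)
        if tensor.shape[0] < self.request_len:  # left pad the prompt
            tensor = F.pad(tensor, (self.request_len - tensor.shape[0], 0), value=self.pad_id)
        return tensor

    @property
    def response_tensor(self) -> torch.Tensor:
        tensor = torch.tensor(self.completion.token_ids, dtype=torch.long)
        if tensor.shape[0] < self.response_len:  # right pad the response
            tensor = F.pad(tensor, (0, self.response_len - tensor.shape[0]), value=self.pad_id)
        return tensor


ep = Episode(
    episode_id="demo",
    pad_id=0,
    request_len=6,
    response_len=6,
    completion=StubCompletion(prompt_ids=[5, 7, 9], token_ids=[11, 13]),
)
print(ep.request_tensor)   # tensor([0, 0, 0, 5, 7, 9])   -- left padded
print(ep.response_tensor)  # tensor([11, 13, 0, 0, 0, 0]) -- right padded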
@@ -222,7 +200,7 @@ class ComputeAdvantages(ForgeActor):
     @endpoint
     async def compute(self, group: Group) -> list[float]:
         # TODO: add batch processing
-        rewards = torch.tensor([[e.reward for e in group.episodes]])
+        rewards = torch.tensor([[e.reward for e in group]])
         mean = rewards.mean(1, keepdim=True)
         std = rewards.std(1, keepdim=True)
         advantages = (rewards - mean) / (std + 1e-4)
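
The normalization itself is unchanged here; only the iteration target moves from group.episodes to the plain list. A quick numerical illustration with made-up rewards for one group of G=4:

import torch

rewards = torch.tensor([[0.0, 1.0, 1.0, 0.0]])  # shape (1, G): rewards for one group
mean = rewards.mean(1, keepdim=True)            # group baseline, 0.5
std = rewards.std(1, keepdim=True)              # ~0.5774 (unbiased std)
advantages = (rewards - mean) / (std + 1e-4)
print(advantages)  # tensor([[-0.8659,  0.8659,  0.8659, -0.8659]])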
@@ -327,12 +305,6 @@ async def main(cfg: DictConfig):
     mlogger = await get_or_create_metric_logger()
     await mlogger.init_backends.call_one(metric_logging_cfg)

-    # In the host mesh v0 case, actors on remote hosts are not able to communicate
-    # with one another. Therefore we use the controller as our storage volume.
-    if not MONARCH_HOSTMESH_V1.get_value():
-        await ts.initialize(strategy=ts.ControllerStorageVolumes())
-        print("Torchstore successfully initialized with controller storage strategy")
-
     # ---- Setup services ---- #

     (
@@ -364,21 +336,19 @@ async def main(cfg: DictConfig):

     print("All services initialized successfully!")
     shutdown_event = asyncio.Event()
-    # In the HostMesh v1 case, we spawn a torchstore storage volume
-    # per trainer process.
+    # Here we spawn a torchstore storage volume per trainer process.
     # We initialize after service initialization because torchstore currently
     # requires access to the underlying proc meshes in the local rank strategy.
     # We should be able to hide this in the future.
-    if MONARCH_HOSTMESH_V1.get_value():
-        # TODO: support multiple host meshes
-        trainer_num_procs = cfg.actors.trainer["procs"]
-        trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
-        trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
-        await ts.initialize(
-            mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
-            strategy=ts.LocalRankStrategy(),
-        )
-        print("Torchstore successfully initialized with local rank strategy")
+    # TODO: support multiple host meshes
+    trainer_num_procs = cfg.actors.trainer["procs"]
+    trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]
+    trainer_hosts = provisioner.get_host_mesh(trainer_host_mesh_name)
+    await ts.initialize(
+        mesh=trainer_hosts.spawn_procs(per_host={"procs": trainer_num_procs}),
+        strategy=ts.LocalRankStrategy(),
+    )
+    print("Torchstore successfully initialized with local rank strategy")

     # ---- Core RL loops ---- #
     async def continuous_rollouts():
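
The torchstore setup now runs unconditionally and reads two fields from the trainer actor config. A hedged sketch of the config shape this assumes (key names taken from the diff, values purely illustrative):

from omegaconf import OmegaConf

# Illustrative values only; the real values come from the app's YAML config.
cfg = OmegaConf.create({"actors": {"trainer": {"procs": 8, "mesh_name": "trainer"}}})

trainer_num_procs = cfg.actors.trainer["procs"]            # -> 8
trainer_host_mesh_name = cfg.actors.trainer["mesh_name"]   # -> "trainer"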
@@ -395,44 +365,32 @@ async def continuous_rollouts():
             t.step("data_loading")

             prompt, target = sample["request"], sample["target"]
-            responses = await policy.generate.route(prompt)
-            # TODO: this shall be part of the responses metadata instead of a separate call
-            version = await policy.get_version.route()
-
+            responses: list[Completion] = await policy.generate.route(prompt)
             t.step("policy_generation")

-            assert (
-                len(responses) > 0
-            ), "Sanity check: Responses should NEVER return empty"
-            assert (
-                version := responses[0].generator_version
-            ) is not None, "Response must indicate a version"
-            group = Group.new_group(
-                group_id=rollout_count,
-                group_size=group_size,
-                request=prompt,
-                policy_version=version,
-                pad_id=pad_id,
-                request_len=max_req_tokens,
-                response_len=max_res_tokens,
-                target=target,
-            )
-
+            # Construct episodes and calculate rewards
+            episodes = []
             input_ids = torch.ones(
                 (group_size, max_req_tokens + max_res_tokens),
                 dtype=torch.long,
-                device="cuda",
             )
-            # Populate episode info and calculate rewards
-            for i, (episode, response) in enumerate(zip(group.episodes, responses)):
-                episode.request_tokens = response.prompt_ids
-                episode.response_tokens = response.token_ids
-                episode.response = response.text
-                input_ids[i, :max_req_tokens] = episode.request_tensor
-                input_ids[i, max_req_tokens:] = episode.response_tensor
+            for i, response in enumerate(responses):
+                episode = Episode(
+                    episode_id=str(uuid.uuid4()),
+                    pad_id=pad_id,
+                    request_len=max_req_tokens,
+                    response_len=max_res_tokens,
+                    target=target,
+                    completion=response,
+                )
                 episode.reward = await reward_actor.evaluate_response.route(
                     prompt=prompt, response=response.text, target=target
                 )
+                episodes.append(episode)
+
+                # Build input_ids for reference logprobs
+                input_ids[i, :max_req_tokens] = episode.request_tensor
+                input_ids[i, max_req_tokens:] = episode.response_tensor

             t.step("reward_evaluation")
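
The rollout loop above packs each episode's padded prompt and response into one row of input_ids for the reference model (now allocated on CPU, since the device="cuda" argument was dropped). A small shape-only sketch with made-up token values, assuming pad_id is 0:

import torch

group_size, max_req_tokens, max_res_tokens = 2, 4, 4
input_ids = torch.ones((group_size, max_req_tokens + max_res_tokens), dtype=torch.long)

# Row 0: left-padded prompt followed by right-padded response.
input_ids[0, :max_req_tokens] = torch.tensor([0, 0, 7, 9])    # request_tensor
input_ids[0, max_req_tokens:] = torch.tensor([11, 13, 0, 0])  # response_tensor
print(input_ids[0])  # tensor([ 0,  0,  7,  9, 11, 13,  0,  0])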

@@ -441,14 +399,13 @@ async def continuous_rollouts():
             )
             t.step("reference_model_calculate_logprobs")

-            for i, episode in enumerate(group.episodes):
+            for i, episode in enumerate(episodes):
                 episode.ref_logprobs = ref_logprobs[i]
             del ref_logprobs, input_ids
-            t.step("compute_logprobs")

             # Calculate advantages and add to replay buffer
-            advantages = await compute_advantages.compute.call_one(group)
-            for episode, advantage in zip(group.episodes, advantages):
+            advantages = await compute_advantages.compute.call_one(episodes)
+            for episode, advantage in zip(episodes, advantages):
                 episode.advantage = advantage
                 await replay_buffer.add.call_one(episode)