53 changes: 26 additions & 27 deletions apps/grpo/main.py
@@ -179,14 +179,19 @@ def simple_grpo_loss(
loss = -(mean_policy_loss - beta * mean_kl)

# Log metrics
# TODO: Better design - have loss function return all metrics as a dict,
Reviewer comment (Member): I want to get away from TODOs in code as a marker. Can you turn this into a small GI?

# then record them in rl_trainer so all training metrics are in one namespace
# and we avoid doing .item here, which is not compile friendly
record_metric("grpo_loss/kl_divergence_mean", mean_kl.item(), Reduce.MEAN)
record_metric(
"grpo_loss/kl_divergence_max", (kl * padding_mask).max().item(), Reduce.MAX
)
record_metric("grpo_loss/policy_loss", mean_policy_loss.item(), Reduce.MEAN)
record_metric(
"grpo_loss/policy_gradient_loss", mean_policy_loss.item(), Reduce.MEAN
)
record_metric("grpo_loss/total_loss", loss.item(), Reduce.MEAN)
record_metric("grpo_loss/advantage_mean", advantages.mean().item(), Reduce.MEAN)
record_metric("grpo_loss/advantage_std", advantages.std().item(), Reduce.MEAN)

return loss
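
To make the TODO above concrete, here is a minimal sketch of the suggested refactor, with hypothetical function names: the loss function returns tensor-valued metrics alongside the loss, and the trainer is the only place that calls .item() and record_metric, so every training metric lands in the rl_trainer namespace and the loss body stays compile-friendly. The import path matches the one used elsewhere in this diff; everything else is illustrative, not the existing API.

import torch

from forge.observability.metrics import record_metric, Reduce


def loss_with_metrics(
    mean_policy_loss: torch.Tensor,
    kl: torch.Tensor,
    advantages: torch.Tensor,
    padding_mask: torch.Tensor,
    beta: float,
) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
    """Return the scalar loss plus tensor-valued metrics; no .item() calls here."""
    mean_kl = (kl * padding_mask).sum() / padding_mask.sum().clamp(min=1)
    loss = -(mean_policy_loss - beta * mean_kl)
    metrics = {
        "kl_divergence_mean": mean_kl.detach(),
        "kl_divergence_max": (kl * padding_mask).max().detach(),
        "policy_gradient_loss": mean_policy_loss.detach(),
        "total_loss": loss.detach(),
        "advantage_mean": advantages.mean().detach(),
        "advantage_std": advantages.std().detach(),
    }
    return loss, metrics


def record_loss_metrics(metrics: dict[str, torch.Tensor]) -> None:
    """Trainer-side recording: the single place where tensors are synced to CPU."""
    reduce_overrides = {"kl_divergence_max": Reduce.MAX}
    for name, value in metrics.items():
        record_metric(
            f"rl_trainer/{name}",
            value.item(),
            reduce_overrides.get(name, Reduce.MEAN),
        )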


@@ -210,11 +215,6 @@ async def evaluate_response(
)
reward_breakdown[reward_fn_name] = reward
# per function reward
record_metric(
f"reward/evaluate_response/sum_{reward_fn_name}_reward",
reward,
Reduce.SUM,
)
record_metric(
f"reward/evaluate_response/avg_{reward_fn_name}_reward",
reward,
Expand All @@ -226,18 +226,13 @@ async def evaluate_response(
Reduce.STD,
)

# avg total reward
Reviewer comment (Member): Remove. It's already called avg_total_reward

record_metric(
"reward/evaluate_response/avg_total_reward",
reward,
Reduce.MEAN,
)

record_metric(
f"reward/evaluate_response/count_{reward_fn_name}_calls",
1,
Reduce.SUM,
)

avg_reward: float = total_rewards / len(self.reward_functions)
return reward_breakdown, avg_reward

@@ -305,17 +300,6 @@ async def sample(self) -> dict[str, str] | None:
try:
sample = next(self._iterator)

record_metric("dataset/sample/count_samples_generated", 1, Reduce.SUM)
record_metric(
"dataset/sample/avg_sample_len",
len(sample["request"]),
Reduce.MEAN,
)
record_metric(
"dataset/sample/max_sample_len",
len(sample["request"]),
Reduce.MAX,
)
record_metric("dataset/sample/current_epoch", self._epoch, Reduce.MAX)

return sample
@@ -442,8 +426,6 @@ async def continuous_rollouts():
print("Dataloader is empty, exiting continuous rollout")
return

t.step("data_loading")

prompt, target = sample["request"], sample["target"]
responses: list[Completion] = await policy.generate.route(prompt)
t.step("policy_generation")
@@ -477,6 +459,23 @@ async def continuous_rollouts():
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor

# Track token-based metrics
prompt_tokens = episode.completion.prompt_ids.shape[0]
response_tokens = episode.completion.token_ids.shape[0]

record_metric("episode/avg_prompt_tokens", prompt_tokens, Reduce.MEAN)
record_metric("episode/max_prompt_tokens", prompt_tokens, Reduce.MAX)
record_metric("episode/min_prompt_tokens", prompt_tokens, Reduce.MIN)
record_metric(
"episode/avg_response_tokens", response_tokens, Reduce.MEAN
)
record_metric(
"episode/max_response_tokens", response_tokens, Reduce.MAX
)
record_metric(
"episode/min_response_tokens", response_tokens, Reduce.MIN
)

# drop episodes if
# 1> reward std-dev is very small (including all 0s and all 1s)
# 2> response is potentially truncated (response_len >= max_res_tokens)
@@ -485,7 +484,7 @@ async def continuous_rollouts():
max_response_len = max(e.completion.token_ids.shape[0] for e in episodes)
drop = rewards_std < 1e-3 or max_response_len >= max_res_tokens
record_metric(
"main/continuous_rollouts/dropped_episodes",
"main/continuous_rollouts/unfit_for_training_dropped_episodes",
Reviewer comment (Member): This naming is a little confusing. Can you explain?

1 if drop else 0,
Reduce.SUM,
)
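
For readers following the thread above, the drop rule can be restated as a standalone predicate. This is an illustrative sketch built only from the comments and the two conditions visible in this hunk; the helper name and argument types are hypothetical.

import torch


def should_drop_group(
    rewards: list[float], response_lens: list[int], max_res_tokens: int
) -> bool:
    # 1) near-zero reward std-dev (e.g. all 0s or all 1s): the group provides
    #    little or no learning signal for advantage computation
    # 2) the longest response hit the token cap, so it was likely truncated
    rewards_std = torch.tensor(rewards).std().item()
    max_response_len = max(response_lens)
    return rewards_std < 1e-3 or max_response_len >= max_res_tokens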
59 changes: 15 additions & 44 deletions src/forge/actors/generator.py
@@ -10,6 +10,7 @@
import logging
import os
import sys
import time
from collections.abc import Mapping
from copy import copy
from dataclasses import dataclass, field
@@ -258,8 +259,6 @@ async def _fetch_weights(
version: int,
) -> dict[str, SharedTensorHandle]:
"""Fetch weights from torchstore and return a dict of {name: SharedTensorHandle}."""
t = Tracer("generator_perf/_fetch_weights")
t.start()
prefix = get_param_prefix(version)
matching_keys = await ts.keys(prefix)
hf_param_names = [extract_param_name(key) for key in matching_keys]
Expand All @@ -282,8 +281,6 @@ def split_keys(keys):
for sd in sub_state_dicts:
state_dict.update(sd)

t.stop()

return state_dict

@endpoint
@@ -336,8 +333,6 @@ async def generate(
priority=priority,
data_parallel_rank=None, # We do not support DP
)
t.step("process_inputs")

# Wait until we're accepting requests (releases lock while waiting)
# If accepting_requests is True, continue immediately (holding the lock)
# If False, release lock, wait for notification, re-acquire and recheck
@@ -369,7 +364,6 @@ async def generate(
self.requests[request_id] = (parent_req, request_fut)

completions = await request_fut
t.step("generate")

# Log some metrics
record_metric(
@@ -378,19 +372,6 @@
Reduce.SUM,
)

for completion in completions:
num_generated_tokens = len(completion.token_ids)
record_metric(
"generator/generate/sum_tokens_generated",
num_generated_tokens,
Reduce.SUM,
)

record_metric(
"generator/generate/avg_tokens_generated",
num_generated_tokens,
Reduce.MEAN,
)
t.stop()
return completions
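
The comments in generate() above and the update_weights() hunk below describe a condition-variable handshake: new generation requests pause while a weight update is in flight, and the weight update waits for pending requests to drain before swapping weights. A minimal, self-contained sketch of that pattern using asyncio.Condition (class and method names here are illustrative, not the generator's actual API):

import asyncio


class WeightSyncGate:
    """Illustrative sketch of the pause/drain handshake, not the real Generator."""

    def __init__(self) -> None:
        self.cond = asyncio.Condition()
        self.accepting_requests = True
        self.in_flight = 0

    async def generate(self, prompt: str) -> str:
        async with self.cond:
            # Releases the lock while waiting; re-acquires and rechecks on wake-up.
            await self.cond.wait_for(lambda: self.accepting_requests)
            self.in_flight += 1
        try:
            return await self._run_inference(prompt)
        finally:
            async with self.cond:
                self.in_flight -= 1
                self.cond.notify_all()

    async def update_weights(self) -> None:
        async with self.cond:
            self.accepting_requests = False
            # Wait until every in-flight request has finished.
            await self.cond.wait_for(lambda: self.in_flight == 0)
            await self._swap_weights()
            self.accepting_requests = True
            self.cond.notify_all()

    async def _run_inference(self, prompt: str) -> str:
        await asyncio.sleep(0)  # stand-in for actual generation
        return prompt

    async def _swap_weights(self) -> None:
        await asyncio.sleep(0)  # stand-in for the actual weight swap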

@@ -465,37 +446,36 @@ async def update_weights(self, version: int) -> None:
async with self.request_lock:
self.accepting_requests = False
curr_requests = [fut for _, fut in self.requests.values()]

if curr_requests:
# Record pending requests metrics
record_metric(
"generator_perf/update_weights/avg_pending_requests",
len(curr_requests),
Reduce.MEAN,
)
# Record pending requests count
record_metric(
"generator_perf/update_weights/max_pending_requests",
"generator_perf/update_weights/sum_pending_gen_requests",
len(curr_requests),
Reduce.MAX,
Reduce.SUM,
)
logger.debug(f"Waiting for {len(curr_requests)} pending requests")

# Start timing the wait
wait_start = time.perf_counter()

# Wait until all pending requests have been processed
# TODO: If generating long sequences, this might be long and will block
# generator weight updates
await self.request_lock.wait_for(lambda: len(self.requests) == 0)

# Record weight update metrics
record_metric(
"generator/update_weights/count_weight_updates", 1, Reduce.SUM
)
if curr_requests:
wait_duration = time.perf_counter() - wait_start
record_metric(
"generator_perf/update_weights/avg_waiting_for_generation_duration_s",
wait_duration,
Reduce.MEAN,
)

logger.debug(f"Starting weight update on {self.__class__.__name__}")

if fetch_fut is not None:
t = Tracer("generator_perf/waiting_for_fetch_weights")
t.start()
fetched_weights = await fetch_fut
t.stop()
# Call update_weights on every policy_worker
await self.worker.update_weights.call(
shared_memory_state_dict=fetched_weights
@@ -672,10 +652,6 @@ async def update_weights(
model = self.worker.model_runner.model
if shared_memory_state_dict is not None:
logger.info("[PolicyWorker] update weights from shared memory.")
t = Tracer(
"generator_worker_perf/update_weights_from_shared_memory", timer="gpu"
)
t.start()
loaded_weights = set()
for name, param_handle in shared_memory_state_dict.items():
# Use context manager for automatic cleanup
@@ -685,7 +661,6 @@ async def update_weights(
del param
loaded_weights.update(loaded)
logger.info(f"[PolicyWorker] updated {len(loaded_weights)} parameters")
t.stop()
return
# normal update_weights without shared memory prefetching
if version is None:
@@ -698,8 +673,6 @@ async def update_weights(
dcp_whole_state_dict_key = get_dcp_whole_state_dict_key(version)
use_dcp_for_weight_sync = dcp_whole_state_dict_key in matching_keys
loaded_weights = set()
t = Tracer("generator_worker_perf/update_weights_from_torchstore", timer="gpu")
t.start()

if use_dcp_for_weight_sync:
dcp_handle = await ts.get(dcp_whole_state_dict_key)
@@ -720,8 +693,6 @@
del param
loaded_weights.update(loaded)

t.stop()

@endpoint
async def save_model_params(self):
"""Save model parameters before weight update, used for testing purposes only."""
10 changes: 1 addition & 9 deletions src/forge/actors/reference_model.py
@@ -144,21 +144,15 @@ async def forward(
"""
# Record reference model metrics
record_metric("reference_perf/forward/count_forward_passes", 1, Reduce.SUM)
record_metric(
"reference_perf/forward/avg_sequence_length",
input_ids.shape[1],
Reduce.MEAN,
)

t = Tracer("reference_perf/forward", timer="gpu", track_memory=True)
t.start()
self.engine.gc_handler.run(self.step)
t.step("garbage_collection")

model_parts = self.engine.model_parts
parallel_dims = self.engine.parallel_dims
input_ids = input_ids.to("cuda")
t.step("to_device")

# optional_context_parallel_ctx = (
# dist_utils.create_context_parallel_ctx(
# cp_mesh=parallel_dims.world_mesh["cp"],
Expand All @@ -182,13 +176,11 @@ async def forward(
self.step += 1
if isinstance(logits, DTensor):
logits = logits.full_tensor()
t.step("forward")

if not return_logprobs:
t.stop()
return logits
else:
logprobs = compute_logprobs(logits, input_ids[:, max_req_tokens:])
t.step("compute_logprobs")
t.stop()
return logprobs
12 changes: 1 addition & 11 deletions src/forge/actors/replay_buffer.py
@@ -13,7 +13,6 @@

from forge.controller import ForgeActor
from forge.observability.metrics import record_metric, Reduce
from forge.observability.perf_tracker import trace

from monarch.actor import endpoint

@@ -75,7 +74,6 @@ async def add(self, episode: "Episode") -> None:
record_metric("buffer/add/count_episodes_added", 1, Reduce.SUM)

@endpoint
@trace("buffer_perf/sample", track_memory=False)
async def sample(
self, curr_policy_version: int
) -> tuple[tuple[Any, ...], ...] | None:
@@ -87,8 +85,6 @@ async def sample(
Returns:
A list of sampled episodes with shape (dp_size, bsz, ...) or None if there are not enough episodes in the buffer.
"""
# Record sample request metric
record_metric("buffer/sample/count_sample_requests", 1, Reduce.SUM)

total_samples = self.dp_size * self.batch_size

@@ -98,7 +94,7 @@ async def sample(
# Calculate metrics
if len(self.buffer) > 0:
record_metric(
"buffer/sample/avg_data_utilization",
"buffer/sample/demand_to_size_ratio",
total_samples / len(self.buffer),
Reduce.MEAN,
)
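
As an illustrative example of the renamed metric (numbers are made up): with dp_size = 4, batch_size = 8, and 64 episodes currently buffered, demand_to_size_ratio = (4 * 8) / 64 = 0.5; values above 1.0 mean the sampler is asking for more episodes than the buffer currently holds.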
@@ -135,12 +131,6 @@ async def sample(
max(sampled_policy_ages),
Reduce.MAX,
)
record_metric(
"buffer/sample/min_sampled_policy_age",
min(sampled_policy_ages),
Reduce.MIN,
)

# Reshape into (dp_size, bsz, ...)
reshaped_episodes = [
sampled_episodes[dp_idx * self.batch_size : (dp_idx + 1) * self.batch_size]
9 changes: 1 addition & 8 deletions src/forge/actors/trainer/titan.py
@@ -176,7 +176,7 @@ async def train_step(

# TODO: delete item() to avoid cpu-gpu sync
loss = loss.detach().item()
record_metric("rl_trainer/avg_loss", loss, Reduce.MEAN)
record_metric("rl_trainer/loss", loss, Reduce.MEAN)

# These are placeholder values until the loss function exposes these metrics
# record_metric("rl_trainer/step/avg_kl_divergence", 0.0, Reduce.MEAN)
@@ -195,8 +195,6 @@
@endpoint
async def push_weights(self, policy_version: int) -> None:
"""Push weights to torchstore in HF format."""
t = Tracer("rl_trainer_perf/push_weights", timer="gpu", track_memory=True)
t.start()
logger.info(f"Pushing weights for policy version {policy_version}")

start_time = time.perf_counter()
@@ -205,13 +203,11 @@ async def push_weights(self, policy_version: int) -> None:

sd = self.engine.checkpointer.states["model"].state_dict()
flattened_state_dict, _ = flatten_state_dict(sd)
t.step("flatten_state_dict")
if self.engine.checkpointer.sd_adapter is None:
raise RuntimeError(
"Trying to save checkpoint in HF safetensors format, but sd_adapter is not provided."
)
hf_state_dict = self.engine.checkpointer.sd_adapter.to_hf(flattened_state_dict)
t.step("to_hf")
if self.use_dcp:
key = get_dcp_whole_state_dict_key(policy_version)
dcp_id = f"{self.dcp_path}/{key}"
@@ -225,13 +221,10 @@ async def push_weights(self, policy_version: int) -> None:
param_names=hf_state_dict.keys(),
)
await ts.put(key, dcp_handle)
t.step("dcp_save")
else:
for name, param in hf_state_dict.items():
key = get_param_key(policy_version, name)
await ts.put(key, param)
t.step("ts_save")
t.stop()
end_time = time.perf_counter()
logger.info("Completed weights push in %.2f seconds", end_time - start_time)
