Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
45 commits
Select commit Hold shift + click to select a range
0f63c4e
add time stamp logging
Oct 3, 2025
f18e9a0
delete file
Oct 3, 2025
f2aa103
review pass
Oct 3, 2025
0aa9e15
nits, tests and linter
Oct 3, 2025
412c453
nit + update env flag
Oct 3, 2025
8ebcc6d
Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…
Oct 3, 2025
0e6a549
delete file + update cfg
Oct 3, 2025
abc6447
Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…
Oct 5, 2025
4ac667a
update configs
Oct 5, 2025
c7c34aa
lint
Oct 5, 2025
372862d
reutilize reduce_metrics_states
Oct 5, 2025
db27d86
change method name
Oct 5, 2025
9d2debf
rename + docstrings
Oct 6, 2025
504d7e1
add comment
Oct 6, 2025
ec86741
update comments
Oct 6, 2025
8037b7a
not initing backends will raise warning instead of breaking
Oct 6, 2025
292d018
Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…
Oct 6, 2025
715c74d
delete file
Oct 6, 2025
83e63b5
config nit
Oct 6, 2025
7edf942
sort prints
Oct 6, 2025
6a28f9e
rename arg
Oct 6, 2025
f21afb7
more arg names
Oct 6, 2025
60e6382
more arg names
Oct 6, 2025
25caeb0
fix wandb hang
Oct 7, 2025
24a5e96
add unit tet for step count
Oct 7, 2025
b726b00
change step -> global_step
Oct 7, 2025
a297090
Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…
Oct 7, 2025
5535eb6
change toy config
Oct 7, 2025
ece12d7
remove comment
Oct 7, 2025
8f1342c
implement samplers
DNXie Oct 3, 2025
0070123
add accumulator
DNXie Oct 4, 2025
a45c075
integrate sampling
DNXie Oct 4, 2025
ed8f50e
update init
DNXie Oct 5, 2025
52ea47d
debug; blocked by wandb table upload bug
DNXie Oct 5, 2025
88a543a
a working version
DNXie Oct 6, 2025
62278a6
a working version
DNXie Oct 6, 2025
4478642
fix ci, lint, add few test cases
DNXie Oct 6, 2025
e859f3d
resolve comments
DNXie Oct 6, 2025
8ede0cb
simplify sampleAccumulator
DNXie Oct 6, 2025
dfba33b
resolve comments2
DNXie Oct 6, 2025
6380f13
fix ci
DNXie Oct 6, 2025
1736df1
more readable
DNXie Oct 7, 2025
e0c6f33
add sample log for toy_rl/metrics; support sample log for log_stream
DNXie Oct 7, 2025
aab8dc5
use incremental table
DNXie Oct 7, 2025
487c01e
fix import error
DNXie Oct 7, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 13 additions & 5 deletions apps/grpo/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
from forge.controller.provisioner import init_provisioner, shutdown
from forge.data.rewards import MathReward, ThinkingReward
from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.metrics import record_metric, Reduce
from forge.observability.metrics import record_episode_sample, record_metric, Reduce
from forge.observability.perf_tracker import Tracer

from forge.types import LauncherConfig, ProvisionerConfig
Expand All @@ -55,6 +55,7 @@ class Episode:
response_tokens: list[int] | None = None
ref_logprobs: torch.Tensor | None = None
reward: float | None = None
reward_breakdown: dict[str, float] | None = None
advantage: float | None = None

@property
Expand Down Expand Up @@ -169,8 +170,11 @@ class RewardActor(ForgeActor):
reward_functions: list[Callable]

@endpoint
async def evaluate_response(self, prompt: str, response: str, target: str) -> float:
async def evaluate_response(
self, prompt: str, response: str, target: str
) -> dict[str, float]:
total_rewards = 0.0
reward_breakdown = {} # reward breakdown by function
for reward_fn in self.reward_functions:
reward = reward_fn(prompt, response, target)
total_rewards += reward
Expand All @@ -179,6 +183,7 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
reward_fn_name = getattr(
reward_fn, "__name__", reward_fn.__class__.__name__
)
reward_breakdown[reward_fn_name] = reward
# per function reward
record_metric(
f"reward/evaluate_response/sum_{reward_fn_name}_reward",
Expand Down Expand Up @@ -211,7 +216,8 @@ async def evaluate_response(self, prompt: str, response: str, target: str) -> fl
)

avg_reward = total_rewards / len(self.reward_functions)
return avg_reward
reward_breakdown["reward"] = avg_reward
return reward_breakdown


@dataclass
Expand Down Expand Up @@ -321,7 +327,7 @@ async def main(cfg: DictConfig):
)
)
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
mlogger = await get_or_create_metric_logger(process_name="Controller")
await mlogger.init_backends.call_one(metric_logging_cfg)
await ts.initialize(strategy=ts.ControllerStorageVolumes())

Expand Down Expand Up @@ -403,9 +409,10 @@ async def continuous_rollouts():
episode.response = response.text
input_ids[i, :max_req_tokens] = episode.request_tensor
input_ids[i, max_req_tokens:] = episode.response_tensor
episode.reward = await reward_actor.evaluate_response.route(
episode.reward_breakdown = await reward_actor.evaluate_response.route(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

i like this, but i think it can be a bit dangerous. We dont have a dataclass that says the fields that it will hold. You would also need to make sure that the other actors are aware of this change. I am thinking that maybe we should keep episode.reward: float and add an extra optional field episode.reward_breakdown: dict[float]. Wdyt?

Copy link
Member Author

@DNXie DNXie Oct 6, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Exactly! it comes with two field reward and reward_breakdown. If you look at the line below it:

episode.reward = episode.reward_breakdown["reward"]

prompt=prompt, response=response.text, target=target
)
episode.reward = episode.reward_breakdown["reward"]

t.step("reward_evaluation")

Expand All @@ -424,6 +431,7 @@ async def continuous_rollouts():
for episode, advantage in zip(group.episodes, advantages):
episode.advantage = advantage
await replay_buffer.add.call_one(episode)
record_episode_sample("rollout/sample", episode)

# Log metrics
rollout_count += 1
Expand Down
9 changes: 5 additions & 4 deletions apps/grpo/qwen3_1_7b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas
# Observability configuration
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
project: grpo-training
group: grpo_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: global_reduce

# Dataset configuration
dataset:
Expand Down
9 changes: 5 additions & 4 deletions apps/grpo/qwen3_32b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas
# Observability configuration
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
project: grpo-training
group: grpo_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: global_reduce

# Dataset configuration
dataset:
Expand Down
9 changes: 5 additions & 4 deletions apps/grpo/qwen3_8b.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ off_by_n: 1 # Off by one by default
# Observability configuration
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
project: grpo-training
group: grpo_exp_${oc.env:USER}
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: global_reduce

# Dataset configuration
dataset:
Expand Down
7 changes: 5 additions & 2 deletions src/forge/controller/provisioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from forge.controller.launcher import BaseLauncher, get_launcher

from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.utils import detect_actor_name_from_call_stack

from forge.types import ProcessConfig, ProvisionerConfig

Expand Down Expand Up @@ -259,8 +260,10 @@ def bootstrap(env: dict[str, str]):

self._proc_host_map[procs] = host_mesh

# Spawn local logging actor on each process and register with global logger
_ = await get_or_create_metric_logger(procs)
# Detect actor name and spawn local logging actor on each process
process_name = detect_actor_name_from_call_stack()
_ = await get_or_create_metric_logger(procs, process_name=process_name)

return procs

async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
Expand Down
2 changes: 1 addition & 1 deletion src/forge/env_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

# Force all timing methods in forge.observability.perf_tracker.py to use
# CPU timer if False or GPU timer if True. If unset, defaults to the assigned value to the function.
METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA"
METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU"

# Makes forge.observability.metrics.record_metric a no-op
FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS"
17 changes: 14 additions & 3 deletions src/forge/observability/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,39 +10,47 @@
LocalFetcherActor,
)
from .metrics import (
BackendRole,
ConsoleBackend,
# Utility functions
get_actor_name_with_rank,
get_logger_backend_class,
# Backend classes
LoggerBackend,
LoggingMode,
MaxAccumulator,
MeanAccumulator,
# Accumulator classes
Metric,
MetricAccumulator,
MetricCollector,
MinAccumulator,
record_episode_sample,
record_metric,
Reduce,
reduce_metrics_states,
SampleAccumulator,
StdAccumulator,
SumAccumulator,
TopBottomKFilter,
WandbBackend,
)
from .perf_tracker import trace, Tracer

__all__ = [
# Main API functions
"record_metric",
"record_episode_sample",
"reduce_metrics_states",
"get_actor_name_with_rank",
"get_logger_backend_class",
"get_or_create_metric_logger",
# Performance tracking
"Tracer",
"trace",
# Data classes
"Metric",
"BackendRole",
# Enums
"Reduce",
"LoggingMode",
# Actor classes
"GlobalLoggingActor",
"LocalFetcherActor",
Expand All @@ -59,4 +67,7 @@
"MaxAccumulator",
"MinAccumulator",
"StdAccumulator",
"SampleAccumulator",
# Filter classes
"TopBottomKFilter",
]
Loading
Loading