8 changes: 6 additions & 2 deletions apps/grpo/main.py
@@ -321,8 +321,7 @@ async def main(cfg: DictConfig):
)
)
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
mlogger = await get_or_create_metric_logger(actor_name="Controller")
await ts.initialize(strategy=ts.ControllerStorageVolumes())

# ---- Setup services ---- #
@@ -351,6 +350,11 @@ async def main(cfg: DictConfig):
),
)

# Call after services are initialized

Contributor:
Would you maybe explain in the comment why init_backends should be called after services are initialized?

Contributor Author (@felipemello1, Oct 6, 2025):
Calling it before works for every mode except when per_rank_share_run=True; then it hangs. wandb says the feature is experimental, and I didn't investigate more deeply to see whether I need to wait for something to finish. But I agree, I will add a note! Edit: done.

Contributor:
Can we debug this further instead of checking in this workaround?

# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
# probably wandb requires primary runs to finish before shared runs can be initialized
await mlogger.init_backends.call_one(metric_logging_cfg)

print("All services initialized successfully!")

# ---- Core RL loops ---- #
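
For reference, a minimal sketch of the ordering this diff establishes in apps/grpo/main.py (service setup elided; import paths taken from elsewhere in this PR):

from omegaconf import DictConfig
from forge.observability.metric_actors import get_or_create_metric_logger

async def main(cfg: DictConfig):
    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})

    # 1) Create the global metric logger, but do not initialize backends yet.
    mlogger = await get_or_create_metric_logger(actor_name="Controller")

    # 2) Spawn services (elided); the provisioner hook registers a local fetcher
    #    on each spawned process.

    # 3) Initialize logging backends only after services exist. Per the TODO above,
    #    calling this earlier hangs when per_rank_share_run=True.
    await mlogger.init_backends.call_one(metric_logging_cfg)
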
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_1_7b.yaml
@@ -14,13 +14,15 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas


# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
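
The same settings expressed as the dict form that init_backends receives (compare apps/toy_rl/toy_metrics/main.py below); the values here are illustrative, and the note on per_rank_share_run is an inference from the old reduce_across_ranks comment:

# logging_mode is one of: "global_reduce", "per_rank_reduce", "per_rank_no_reduce"
metric_logging_cfg = {
    "wandb": {
        "project": "grpo-training",
        "group": "grpo_exp_myuser",      # the YAML interpolates ${oc.env:USER} here
        "logging_mode": "global_reduce",
        "per_rank_share_run": False,     # presumably only relevant in per-rank modes
    },
    "console": {"logging_mode": "global_reduce"},
}
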
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_32b.yaml
@@ -17,13 +17,15 @@ provisioner:
rollout_threads: 1 # Recommended to set equal to policy.num_replicas

# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_8b.yaml
@@ -10,13 +10,15 @@ model: "Qwen/Qwen3-8B"
off_by_n: 1 # Off by one by default

# Observability configuration
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
metric_logging:
wandb:
project: "grpo-training"
group: "grpo_exp_${oc.env:USER}"
reduce_across_ranks: True
logging_mode: "global_reduce"
per_rank_share_run: False
console:
reduce_across_ranks: True
logging_mode: "global_reduce"

# Dataset configuration
dataset:
21 changes: 13 additions & 8 deletions apps/toy_rl/toy_metrics/main.py
@@ -18,6 +18,7 @@
from monarch.actor import current_rank, endpoint

logging.basicConfig(level=logging.DEBUG)
logging.getLogger("forge.observability.metrics").setLevel(logging.DEBUG)


class TrainActor(ForgeActor):
@@ -82,31 +83,35 @@ async def main():
group = f"grpo_exp_{int(time.time())}"

# Config format: {backend_name: backend_config_dict}
# Each backend can specify reduce_across_ranks to control distributed logging behavior
config = {
"console": {"reduce_across_ranks": True},
"console": {"logging_mode": "per_rank_reduce"},
"wandb": {
"project": "my_project",
"project": "toy_metrics",
"group": group,
"reduce_across_ranks": False,
# Only useful if NOT reduce_across_ranks.
"share_run_id": False, # Share run ID across ranks -- Not recommended.
"logging_mode": "per_rank_no_reduce",
"per_rank_share_run": False,
},
}

service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(config)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

# Spawn services first (triggers registrations via provisioner hook)
trainer = await TrainActor.options(**service_config).as_service()
generator = await GeneratorActor.options(**service_config).as_service()

# Call after services are initialized
# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
# probably wandb requires primary runs to finish before shared runs can be initialized
await mlogger.init_backends.call_one(config)

for i in range(3):
print(f"\n=== Global Step {i} ===")
record_metric("main/global_step", 1, Reduce.MEAN)
await trainer.train_step.fanout(i)
for sub in range(3):
await generator.generate_step.fanout(i, sub)
await asyncio.sleep(0.1)
await mlogger.flush.call_one(i)

# shutdown
8 changes: 6 additions & 2 deletions apps/vllm/main.py
@@ -27,8 +27,7 @@

async def run(cfg: DictConfig):
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
mlogger = await get_or_create_metric_logger()
await mlogger.init_backends.call_one(metric_logging_cfg)
mlogger = await get_or_create_metric_logger(actor_name="Controller")

if (prompt := cfg.get("prompt")) is None:
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
@@ -37,6 +36,11 @@ async def run(cfg: DictConfig):
print("Spawning service...")
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)

# Call after services are initialized
# TODO (felipemello): if called before, and per_rank_share_run=True, it hangs
# probably wandb requires primary runs to finish before shared runs can be initialized
await mlogger.init_backends.call_one(metric_logging_cfg)

import time

print("Requesting generation...")
7 changes: 5 additions & 2 deletions src/forge/controller/provisioner.py
@@ -21,6 +21,7 @@
from forge.controller.launcher import BaseLauncher, get_launcher

from forge.observability.metric_actors import get_or_create_metric_logger
from forge.observability.utils import detect_actor_name_from_call_stack

from forge.types import ProcessConfig, ProvisionerConfig

@@ -253,8 +254,10 @@ def bootstrap(env: dict[str, str]):

self._proc_host_map[procs] = host_mesh

# Spawn local logging actor on each process and register with global logger
_ = await get_or_create_metric_logger(procs)
# Detect actor name and spawn local logging actor on each process
actor_name = detect_actor_name_from_call_stack()
_ = await get_or_create_metric_logger(procs, actor_name=actor_name)

return procs

async def host_mesh_from_proc(self, proc_mesh: ProcMesh):
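
detect_actor_name_from_call_stack comes from forge.observability.utils, and its implementation is not part of this diff; the following is only a rough sketch of how call-stack-based detection could look, not the actual helper:

import inspect

def detect_actor_name_sketch(default: str = "UnknownActor") -> str:
    # Walk the call stack and return the class name of the first frame whose
    # `self` looks like an actor (hypothetical heuristic for illustration only).
    for frame_info in inspect.stack():
        self_obj = frame_info.frame.f_locals.get("self")
        if self_obj is not None and type(self_obj).__name__.endswith("Actor"):
            return type(self_obj).__name__
    return default
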
2 changes: 1 addition & 1 deletion src/forge/env_constants.py
@@ -11,7 +11,7 @@

# Force all timing methods in forge.observability.perf_tracker.py to use
# CPU timer if False or GPU timer if True. If unset, defaults to the value assigned in the function.
METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA"
METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU"

Member:
?

Contributor Author:
Making it future-proof for when we support other backends besides CUDA.

Member:
TPUs, here we come.

Contributor Author:
OH NO

Contributor Author:
"METRIC_TIMER_USES_ACCELERATOR"? Is this what torch uses? Geez, we'll refactor it when the time comes.


# Makes forge.observability.metrics.record_metric a no-op
FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS"
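
A short illustration of how these environment flags might be read; the variable names come from this file, but the parsing and defaulting shown here are assumptions rather than the actual perf_tracker/metrics code:

import os

def _env_flag(name: str, default: bool | None = None) -> bool | None:
    # Illustrative parsing only; the real forge code may differ.
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.strip().lower() in ("1", "true", "yes")

# Force GPU timing if set, otherwise fall back to each function's own default.
use_gpu_timer = _env_flag("METRIC_TIMER_USES_GPU", default=None)
# Treat metrics as disabled when the flag is set.
metrics_disabled = _env_flag("FORGE_DISABLE_METRICS", default=False)
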
10 changes: 7 additions & 3 deletions src/forge/observability/__init__.py
@@ -11,20 +11,20 @@
)
from .metrics import (
ConsoleBackend,
# Utility functions
get_actor_name_with_rank,
get_logger_backend_class,
# Backend classes
LoggerBackend,
LoggingMode,
MaxAccumulator,
MeanAccumulator,
# Accumulator classes
Metric,
MetricAccumulator,
MetricCollector,
MinAccumulator,
record_metric,
Reduce,
reduce_metrics_states,
Role,
StdAccumulator,
SumAccumulator,
WandbBackend,
@@ -41,8 +41,12 @@
# Performance tracking
"Tracer",
"trace",
# Data classes
"Metric",
"Role",
# Enums
"Reduce",
"LoggingMode",
# Actor classes
"GlobalLoggingActor",
"LocalFetcherActor",
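
For context, a minimal use of the re-exported API, mirroring the call in apps/toy_rl/toy_metrics/main.py (metric name and value are just examples):

from forge.observability import record_metric, Reduce

# Record a value with MEAN reduction, as in the toy_metrics training loop.
record_metric("main/global_step", 1, Reduce.MEAN)
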