5 changes: 3 additions & 2 deletions apps/grpo/main.py
@@ -334,8 +334,7 @@ async def main(cfg: DictConfig):

     # initialize before spawning services
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
-    mlogger = await get_or_create_metric_logger()
-    await mlogger.init_backends.call_one(metric_logging_cfg)
+    mlogger = await get_or_create_metric_logger(actor_name="Controller")

     # ---- Setup services ---- #
     await ts.initialize(strategy=ts.ControllerStorageVolumes())
@@ -363,6 +362,8 @@ async def main(cfg: DictConfig):
         ),
     )

+    await mlogger.init_backends.call_one(metric_logging_cfg)
+
     print("All services initialized successfully!")

     # ---- Core RL loops ---- #
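For reference, a condensed sketch of the initialization ordering this change establishes: the controller-side logger is created before any services so the provisioner hook can register each spawned process with it, and the logging backends are only initialized once all services exist. Names come from the diff above; the surrounding setup is elided and hypothetical.

```python
# Sketch only: condensed from the diff above; the real apps/grpo/main.py
# performs much more setup between these calls.
from forge.observability.metric_actors import get_or_create_metric_logger


async def main_sketch(cfg):
    metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})

    # 1) Create the controller-side logger BEFORE spawning services, so the
    #    provisioner hook can register each spawned process with it.
    mlogger = await get_or_create_metric_logger(actor_name="Controller")

    # 2) ... spawn dataset / policy / trainer / reward services here ...

    # 3) Initialize logging backends only after all services exist.
    await mlogger.init_backends.call_one(metric_logging_cfg)
```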
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_1_7b.yaml
@@ -14,13 +14,15 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas


 # Observability configuration
+# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
 metric_logging:
   wandb:
     project: "grpo-training"
     group: "grpo_exp_${oc.env:USER}"
-    reduce_across_ranks: True
+    logging_mode: "global_reduce"
+    per_rank_share_run: False
   console:
-    reduce_across_ranks: True
+    logging_mode: "global_reduce"

 # Dataset configuration
 dataset:
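The new comment enumerates three modes: global_reduce, per_rank_reduce, and per_rank_no_reduce. For comparison, here is a dict-style version of the same config in the format used by the toy_metrics app later in this PR; the per-mode descriptions in the comments are inferred from the names and from how the diffs in this PR use them, not from documentation.

```python
# Dict-style equivalent of the YAML above (same format as apps/toy_rl/toy_metrics).
# Mode descriptions are inferred, not authoritative:
#   global_reduce      - reduce across ranks and log one stream (replaces reduce_across_ranks: True)
#   per_rank_reduce    - each rank reduces and logs its own stream
#   per_rank_no_reduce - each rank logs raw values immediately, with no reduction
metric_logging = {
    "wandb": {
        "project": "grpo-training",
        "group": "grpo_exp_${oc.env:USER}",
        "logging_mode": "global_reduce",
        "per_rank_share_run": False,  # only relevant for the per_rank_* modes
    },
    "console": {"logging_mode": "global_reduce"},
}
```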
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_32b.yaml
@@ -14,13 +14,15 @@ off_by_n: 1 # Off by one by default
 rollout_threads: 1 # Recommended to set equal to policy.num_replicas

 # Observability configuration
+# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
 metric_logging:
   wandb:
     project: "grpo-training"
     group: "grpo_exp_${oc.env:USER}"
-    reduce_across_ranks: True
+    logging_mode: "global_reduce"
+    per_rank_share_run: False
   console:
-    reduce_across_ranks: True
+    logging_mode: "global_reduce"

 # Dataset configuration
 dataset:
6 changes: 4 additions & 2 deletions apps/grpo/qwen3_8b.yaml
@@ -10,13 +10,15 @@ model: "Qwen/Qwen3-8B"
 off_by_n: 1 # Off by one by default

 # Observability configuration
+# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
 metric_logging:
   wandb:
     project: "grpo-training"
     group: "grpo_exp_${oc.env:USER}"
-    reduce_across_ranks: True
+    logging_mode: "global_reduce"
+    per_rank_share_run: False
   console:
-    reduce_across_ranks: True
+    logging_mode: "global_reduce"

 # Dataset configuration
 dataset:
19 changes: 11 additions & 8 deletions apps/toy_rl/toy_metrics/main.py
@@ -18,6 +18,7 @@
 from monarch.actor import current_rank, endpoint

 logging.basicConfig(level=logging.DEBUG)
+logging.getLogger("forge.observability.metrics").setLevel(logging.DEBUG)


 class TrainActor(ForgeActor):
@@ -82,31 +83,33 @@ async def main():
     group = f"grpo_exp_{int(time.time())}"

     # Config format: {backend_name: backend_config_dict}
-    # Each backend can specify reduce_across_ranks to control distributed logging behavior
     config = {
-        "console": {"reduce_across_ranks": True},
+        "console": {"logging_mode": "per_rank_reduce"},
         "wandb": {
-            "project": "my_project",
+            "project": "immediate_logging_test",
             "group": group,
-            "reduce_across_ranks": False,
-            # Only useful if NOT reduce_across_ranks.
-            "share_run_id": False,  # Share run ID across ranks -- Not recommended.
+            "logging_mode": "per_rank_no_reduce",
+            "per_rank_share_run": False,
         },
     }

     service_config = {"procs": 2, "num_replicas": 2, "with_gpus": False}
-    mlogger = await get_or_create_metric_logger()
-    await mlogger.init_backends.call_one(config)
+    mlogger = await get_or_create_metric_logger(actor_name="Controller")

+    # Spawn services first (triggers registrations via provisioner hook)
     trainer = await TrainActor.options(**service_config).as_service()
     generator = await GeneratorActor.options(**service_config).as_service()

+    # Initialize after spawning services
+    await mlogger.init_backends.call_one(config)
+
     for i in range(3):
         print(f"\n=== Global Step {i} ===")
+        record_metric("main/global_step", 1, Reduce.MEAN)
         await trainer.train_step.fanout(i)
         for sub in range(3):
             await generator.generate_step.fanout(i, sub)
         await asyncio.sleep(0.1)
         await mlogger.flush.call_one(i)

     # shutdown
6 changes: 4 additions & 2 deletions apps/vllm/main.py
@@ -27,8 +27,7 @@

 async def run(cfg: DictConfig):
     metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
-    mlogger = await get_or_create_metric_logger()
-    await mlogger.init_backends.call_one(metric_logging_cfg)
+    mlogger = await get_or_create_metric_logger(actor_name="Controller")

     if (prompt := cfg.get("prompt")) is None:
         gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
@@ -37,6 +36,9 @@ async def run(cfg: DictConfig):
     print("Spawning service...")
     policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)

+    # initialize after spawning services
+    await mlogger.init_backends.call_one(metric_logging_cfg)
+
     import time

     print("Requesting generation...")
6 changes: 4 additions & 2 deletions src/forge/controller/provisioner.py
@@ -21,6 +21,7 @@
 from forge.controller.launcher import BaseLauncher, get_launcher

 from forge.observability.metric_actors import get_or_create_metric_logger
+from forge.observability.utils import detect_actor_name_from_call_stack

 from forge.types import ProcessConfig, ProvisionerConfig

@@ -201,8 +202,9 @@ def bootstrap(gpu_ids: list[str]):
         self._server_names.append(server_name)
         self._proc_server_map[procs] = server_name

-        # Spawn local logging actor on each process and register with global logger
-        _ = await get_or_create_metric_logger(procs)
+        # Detect actor name and spawn local logging actor on each process
+        actor_name = detect_actor_name_from_call_stack()
+        _ = await get_or_create_metric_logger(procs, actor_name=actor_name)

         return procs

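detect_actor_name_from_call_stack is new in this PR and its body is not part of this diff. Below is a hypothetical sketch of the general technique it is named after (walking the call stack for an enclosing actor class); the real helper in forge.observability.utils may work differently.

```python
import inspect


def detect_actor_name_sketch(default: str = "UnknownActor") -> str:
    """Hypothetical illustration only: walk the call stack and return the
    name of the first enclosing class whose name looks like an actor."""
    for frame_info in inspect.stack():
        self_obj = frame_info.frame.f_locals.get("self")
        if self_obj is not None and type(self_obj).__name__.endswith("Actor"):
            return type(self_obj).__name__
    return default
```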
2 changes: 1 addition & 1 deletion src/forge/env_constants.py
@@ -11,7 +11,7 @@

 # Force all timing methods in forge.observability.perf_tracker.py to use
 # CPU timer if False or GPU timer if True. If unset, defaults to the assigned value to the function.
-METRIC_TIMER_USES_CUDA = "METRIC_TIMER_USES_CUDA"
+METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU"

Review thread on the renamed constant:

Member: ?

Contributor Author: making it future proof when we support other backends besides cuda

Member: TPUs here we come

Contributor Author: OH NO

Contributor Author: "METRIC_TIMER_USES_ACCELERATOR"? Is this what torch uses? geez, we refactor it when the time comes

 # Makes forge.observability.metrics.record_metric a no-op
 FORGE_DISABLE_METRICS = "FORGE_DISABLE_METRICS"
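The renamed constant is only an environment-variable key; how perf_tracker consumes it is outside this diff. A minimal sketch, assuming the variable is read as a truthy string that overrides a per-function default (the actual perf_tracker logic may differ):

```python
import os

# Key name taken from the diff above; the consuming logic is an assumption.
METRIC_TIMER_USES_GPU = "METRIC_TIMER_USES_GPU"


def pick_timer(default_uses_gpu: bool) -> str:
    """Sketch: the env var, if set, overrides the per-function default.
    Returns "gpu" or "cpu"."""
    override = os.environ.get(METRIC_TIMER_USES_GPU)
    if override is not None:
        return "gpu" if override.lower() in ("1", "true") else "cpu"
    return "gpu" if default_uses_gpu else "cpu"
```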
10 changes: 7 additions & 3 deletions src/forge/observability/__init__.py
@@ -11,20 +11,20 @@
)
from .metrics import (
    ConsoleBackend,
    # Utility functions
    get_actor_name_with_rank,
    get_logger_backend_class,
    # Backend classes
    LoggerBackend,
    LoggingMode,
    MaxAccumulator,
    MeanAccumulator,
    # Accumulator classes
    Metric,
    MetricAccumulator,
    MetricCollector,
    MinAccumulator,
    record_metric,
    Reduce,
    reduce_metrics_states,
    Role,
    StdAccumulator,
    SumAccumulator,
    WandbBackend,
@@ -41,8 +41,12 @@
    # Performance tracking
    "Tracer",
    "trace",
    # Data classes
    "Metric",
    "Role",
    # Enums
    "Reduce",
    "LoggingMode",
    # Actor classes
    "GlobalLoggingActor",
    "LocalFetcherActor",