Skip to content

Commit 4ac667a

Browse files
author
Felipe Mello
committed
update configs
1 parent abc6447 commit 4ac667a

File tree

6 files changed

+20
-13
lines changed

6 files changed

+20
-13
lines changed

apps/grpo/main.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -334,7 +334,7 @@ async def main(cfg: DictConfig):
334334

335335
# initialize before spawning services
336336
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
337-
mlogger = await get_or_create_metric_logger()
337+
mlogger = await get_or_create_metric_logger(actor_name="Controller")
338338

339339
# ---- Setup services ---- #
340340
await ts.initialize(strategy=ts.ControllerStorageVolumes())

apps/grpo/qwen3_1_7b.yaml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas
1414

1515

1616
# Observability configuration
17+
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
1718
metric_logging:
1819
wandb:
19-
project: "test"
20+
project: "grpo-training"
2021
group: "grpo_exp_${oc.env:USER}"
21-
logging_mode: "per_rank_no_reduce"
22-
"per_rank_share_run": True
22+
logging_mode: "global_reduce"
23+
per_rank_share_run: False
2324
console:
24-
logging_mode: "per_rank_no_reduce"
25+
logging_mode: "global_reduce"
2526

2627
# Dataset configuration
2728
dataset:

apps/grpo/qwen3_32b.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,15 @@ off_by_n: 1 # Off by one by default
1414
rollout_threads: 1 # Recommended to set equal to policy.num_replicas
1515

1616
# Observability configuration
17+
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
1718
metric_logging:
1819
wandb:
1920
project: "grpo-training"
2021
group: "grpo_exp_${oc.env:USER}"
21-
reduce_across_ranks: True
22+
logging_mode: "global_reduce"
23+
per_rank_share_run: False
2224
console:
23-
reduce_across_ranks: True
25+
logging_mode: "global_reduce"
2426

2527
# Dataset configuration
2628
dataset:

apps/grpo/qwen3_8b.yaml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,13 +10,15 @@ model: "Qwen/Qwen3-8B"
1010
off_by_n: 1 # Off by one by default
1111

1212
# Observability configuration
13+
# logging_mode: global_reduce, per_rank_reduce, per_rank_no_reduce
1314
metric_logging:
1415
wandb:
1516
project: "grpo-training"
1617
group: "grpo_exp_${oc.env:USER}"
17-
reduce_across_ranks: True
18+
logging_mode: "global_reduce"
19+
per_rank_share_run: False
1820
console:
19-
reduce_across_ranks: True
21+
logging_mode: "global_reduce"
2022

2123
# Dataset configuration
2224
dataset:

apps/vllm/main.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@
2727

2828
async def run(cfg: DictConfig):
2929
metric_logging_cfg = cfg.get("metric_logging", {"console": {"log_per_rank": False}})
30-
mlogger = await get_or_create_metric_logger()
31-
await mlogger.init_backends.call_one(metric_logging_cfg)
30+
mlogger = await get_or_create_metric_logger(actor_name="Controller")
3231

3332
if (prompt := cfg.get("prompt")) is None:
3433
gd = cfg.policy.get("sampling_config", {}).get("guided_decoding", False)
@@ -37,6 +36,9 @@ async def run(cfg: DictConfig):
3736
print("Spawning service...")
3837
policy = await Policy.options(**cfg.services.policy).as_service(**cfg.policy)
3938

39+
# initialize after spawning services
40+
await mlogger.init_backends.call_one(metric_logging_cfg)
41+
4042
import time
4143

4244
print("Requesting generation...")

src/forge/observability/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,10 @@
1515

1616
import pytz
1717

18-
from monarch.actor import current_rank
19-
2018
from forge.observability.utils import get_actor_name_with_rank
2119

20+
from monarch.actor import current_rank
21+
2222
logger = logging.getLogger(__name__)
2323

2424

0 commit comments

Comments
 (0)