Skip to content

Commit feb4865

Browse files
felipemello1Felipe Mello
andauthored
[Metric logging] log config.yaml (#605)
Co-authored-by: Felipe Mello <[email protected]>
1 parent 47c2333 commit feb4865

File tree

3 files changed

+81
-28
lines changed

3 files changed

+81
-28
lines changed

apps/grpo/main.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import torch
1616
import torch.nn.functional as F
1717
import torchstore as ts
18+
import yaml
1819
from datasets import load_dataset
1920
from forge.actors._torchstore_utils import (
2021
get_dcp_whole_state_dict_key,
@@ -33,11 +34,14 @@
3334
from forge.observability.perf_tracker import Tracer
3435
from forge.types import LauncherConfig, ProvisionerConfig
3536
from forge.util.config import parse
37+
from forge.util.logging import get_logger
3638
from forge.util.ops import compute_logprobs
3739
from monarch.actor import endpoint
38-
from omegaconf import DictConfig
40+
from omegaconf import DictConfig, OmegaConf
3941
from vllm.transformers_utils.tokenizer import get_tokenizer
4042

43+
logger = get_logger("INFO")
44+
4145

4246
@dataclass
4347
class Episode:
@@ -358,9 +362,14 @@ async def drop_weights(version: int):
358362

359363
async def main(cfg: DictConfig):
360364
"""Main GRPO training loop with rollout and training processes."""
361-
group_size = cfg.group_size
362-
max_req_tokens = cfg.max_req_tokens
363-
max_res_tokens = cfg.max_res_tokens
365+
# Convert OmegaConf config to plain dict
366+
run_config_for_logging = OmegaConf.to_container(cfg, resolve=True)
367+
368+
# Log config
369+
logger.info("=" * 30 + " CONFIGURATION " + "=" * 30)
370+
logger.info(
371+
yaml.dump(run_config_for_logging, default_flow_style=False, sort_keys=False)
372+
)
364373

365374
# ---- Global setups ---- #
366375
provisioner = None
@@ -372,8 +381,11 @@ async def main(cfg: DictConfig):
372381
provisioner = await init_provisioner()
373382

374383
metric_logging_cfg = cfg.get("metric_logging", {})
384+
375385
mlogger = await get_or_create_metric_logger(process_name="Controller")
376-
await mlogger.init_backends.call_one(metric_logging_cfg)
386+
await mlogger.init_backends.call_one(
387+
backend_config=metric_logging_cfg, run_config=run_config_for_logging
388+
)
377389

378390
# ---- Setup services ---- #
379391

@@ -411,6 +423,10 @@ async def main(cfg: DictConfig):
411423
),
412424
)
413425

426+
group_size = cfg.group_size
427+
max_req_tokens = cfg.max_req_tokens
428+
max_res_tokens = cfg.max_res_tokens
429+
414430
# Set max_steps to the configured value, or -1 if not specified or Null
415431
max_steps = cfg.trainer.training.steps or -1
416432

src/forge/observability/metric_actors.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -169,22 +169,27 @@ async def flush(
169169
async def init_backends(
170170
self,
171171
metadata_per_controller_backend: dict[str, dict[str, Any]],
172-
config: dict[str, Any],
172+
backend_config: dict[str, Any],
173+
run_config: dict[str, Any] | None = None,
173174
global_step: int = 0,
174175
) -> None:
175176
"""Init per-rank logger backends and MetricCollector.
176177
177178
Args:
178179
metadata_per_controller_backend (dict[str, dict[str, Any]]): Metadata from controller backends for shared state.
179-
config (dict[str, Any]): Backend configurations with logging modes and settings.
180+
backend_config (dict[str, Any]): Backend configurations with logging modes and settings.
181+
run_config (dict[str, Any] | None): Your application's configuration
182+
(hyperparameters, dataset, model settings) to log to backends for
183+
experiment tracking.
180184
global_step (int): Initial step for metrics.
181185
"""
182186
collector = MetricCollector()
183187
await collector.init_backends(
184188
metadata_per_controller_backend,
185-
config,
189+
backend_config,
186190
global_step,
187191
process_name=self.process_name,
192+
run_config=run_config,
188193
)
189194

190195
@endpoint
@@ -211,6 +216,7 @@ class GlobalLoggingActor(ForgeActor):
211216
def __init__(self):
212217
self.fetchers: dict[str, LocalFetcherActor] = {}
213218
self.config: dict[str, Any] | None = None
219+
self.run_config: dict[str, Any] | None = None
214220
self.global_logger_backends: dict[str, LoggerBackend] = {}
215221
self.metadata_per_controller_backend: dict[str, dict[str, Any]] = {}
216222

@@ -267,15 +273,17 @@ def _validate_backend_config(
267273
}
268274

269275
@endpoint
270-
async def init_backends(self, config: dict[str, Any]) -> None:
276+
async def init_backends(
277+
self, backend_config: dict[str, Any], run_config: dict[str, Any] | None = None
278+
) -> None:
271279
"""Sets config in global actor and initializes existing backends and collectors. Later spawned actors
272280
are initialized in `register_fetcher` endpoint.
273281
274282
Controller backends (instantiated in the controller) can provide metadata to be shared with rank backends,
275283
e.g. shared run IDs for WandB. For details on logging modes, see `forge.observability.metrics.LoggingMode`.
276284
277285
Args:
278-
config (dict[str, Any]): Config for metric logging where keys are backend names.
286+
backend_config (dict[str, Any]): Config for metric logging where keys are backend names.
279287
Each backend config supports:
280288
- logging_mode (str | LoggingMode): Check LoggingMode for options. Defaults to "global_reduce".
281289
- per_rank_share_run (bool, default False): For per-rank modes only. Whether ranks
@@ -291,21 +299,23 @@ async def init_backends(self, config: dict[str, Any]) -> None:
291299
"project": "my_project",
292300
}
293301
}
302+
run_config (dict[str, Any] | None): Your application's configuration
303+
(hyperparameters, dataset, model settings) to log to backends for
304+
experiment tracking.
294305
295306
Raises:
296307
ValueError: If backend config is invalid or missing required fields.
297308
"""
298309
self.config = {}
310+
self.run_config = run_config
299311

300312
# Skip initialization if disabled by environment flag
301313
if FORGE_DISABLE_METRICS.get_value():
302314
return
303315

304316
# Validate and normalize each backend config
305-
for backend_name, backend_config in config.items():
306-
self.config[backend_name] = self._validate_backend_config(
307-
backend_name, backend_config
308-
)
317+
for backend_name, cfg in backend_config.items():
318+
self.config[backend_name] = self._validate_backend_config(backend_name, cfg)
309319

310320
# Initialize backends based on logging mode
311321
for backend_name, backend_config in self.config.items():
@@ -314,7 +324,11 @@ async def init_backends(self, config: dict[str, Any]) -> None:
314324
backend: LoggerBackend = get_logger_backend_class(backend_name)(
315325
**backend_config
316326
)
317-
await backend.init(role=BackendRole.GLOBAL, process_name="global_reduce")
327+
await backend.init(
328+
role=BackendRole.GLOBAL,
329+
process_name="global_reduce",
330+
run_config=self.run_config,
331+
)
318332

319333
# Extract metadata from controller logger to be shared with per-rank loggers
320334
if mode != LoggingMode.GLOBAL_REDUCE:
@@ -331,7 +345,7 @@ async def init_backends(self, config: dict[str, Any]) -> None:
331345
if self.fetchers:
332346
tasks = [
333347
fetcher.init_backends.call(
334-
self.metadata_per_controller_backend, self.config
348+
self.metadata_per_controller_backend, self.config, self.run_config
335349
)
336350
for fetcher in self.fetchers.values()
337351
]
@@ -351,7 +365,7 @@ async def register_fetcher(self, fetcher: LocalFetcherActor, proc_id: str) -> No
351365
if self.config:
352366
logger.debug(f"Initializing new LocalFetcherActor for proc_id={proc_id}")
353367
await fetcher.init_backends.call(
354-
self.metadata_per_controller_backend, self.config
368+
self.metadata_per_controller_backend, self.config, self.run_config
355369
)
356370

357371
@endpoint

src/forge/observability/metrics.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -550,9 +550,10 @@ def __init__(self) -> None:
550550
async def init_backends(
551551
self,
552552
metadata_per_controller_backend: dict[str, dict[str, Any]] | None,
553-
config: dict[str, Any],
553+
backend_config: dict[str, Any],
554554
global_step: int = 0,
555555
process_name: str | None = None,
556+
run_config: dict[str, Any] | None = None,
556557
) -> None:
557558
"""Initialize per-rank logger backends and MetricCollector state.
558559
@@ -563,12 +564,15 @@ async def init_backends(
563564
metadata_per_controller_backend (Optional[Dict[str, Dict[str, Any]]]): Metadata from controller
564565
for backends that require shared state across processes, e.g.,
565566
{"wandb": {"shared_run_id": "abc123"}}.
566-
config (Dict[str, Any]): Backend configurations where each key is a backend name
567+
backend_config (Dict[str, Any]): Backend configurations where each key is a backend name
567568
and value contains logging_mode and backend-specific settings.
568569
e.g., {"wandb": {"logging_mode": "per_rank_no_reduce", "project": "my_proj"}}
569570
global_step (int, default 0): Initial step for logging. Can be used when
570571
resuming from a checkpoint.
571572
process_name (str | None): The meaningful process name for logging.
573+
run_config (dict[str, Any] | None): Your application's configuration
574+
(hyperparameters, dataset, model settings) to log to backends for
575+
experiment tracking.
572576
"""
573577
if self._is_initialized:
574578
logger.debug(
@@ -583,8 +587,8 @@ async def init_backends(
583587
self.per_rank_no_reduce_backends: list[LoggerBackend] = []
584588

585589
# Initialize backends based on logging mode
586-
for backend_name, backend_config in config.items():
587-
mode = backend_config["logging_mode"]
590+
for backend_name, cfg in backend_config.items():
591+
mode = cfg["logging_mode"]
588592

589593
# sanity check
590594
if not isinstance(mode, LoggingMode):
@@ -605,13 +609,12 @@ async def init_backends(
605609
)
606610

607611
# instantiate local backend
608-
backend: LoggerBackend = get_logger_backend_class(backend_name)(
609-
**backend_config
610-
)
612+
backend: LoggerBackend = get_logger_backend_class(backend_name)(**cfg)
611613
await backend.init(
612614
role=BackendRole.LOCAL,
613615
controller_logger_metadata=controller_metadata,
614616
process_name=self.proc_name_with_rank,
617+
run_config=run_config,
615618
)
616619

617620
# Categorize by logging mode
@@ -781,6 +784,7 @@ async def init(
781784
role: BackendRole,
782785
controller_logger_metadata: dict[str, Any] | None = None,
783786
process_name: str | None = None,
787+
run_config: dict[str, Any] | None = None,
784788
) -> None:
785789
"""
786790
Initializes backend, e.g. wandb.run.init().
@@ -791,6 +795,9 @@ async def init(
791795
controller_logger_metadata (dict[str, Any] | None): From global backend for
792796
backend that required shared info, e.g. {"shared_run_id": "abc123"}.
793797
process_name (str | None): Process name for logging.
798+
run_config (dict[str, Any] | None): Your application's configuration
799+
(hyperparameters, dataset, model settings) to log to backend for
800+
experiment tracking.
794801
795802
Raises: ValueError if missing metadata for shared local init.
796803
"""
@@ -856,6 +863,7 @@ async def init(
856863
role: BackendRole,
857864
controller_logger_metadata: dict[str, Any] | None = None,
858865
process_name: str | None = None,
866+
run_config: dict[str, Any] | None = None,
859867
) -> None:
860868
self.process_name = process_name
861869

@@ -927,13 +935,15 @@ async def init(
927935
role: BackendRole,
928936
controller_logger_metadata: dict[str, Any] | None = None,
929937
process_name: str | None = None,
938+
run_config: dict[str, Any] | None = None,
930939
) -> None:
931940
if controller_logger_metadata is None:
932941
controller_logger_metadata = {}
933942

934943
# Pop name, if any, to concat to process_name.
935944
run_name = self.backend_kwargs.pop("name", None)
936945
self.process_name = process_name
946+
self.run_config = run_config
937947

938948
# Format run name based on mode and role
939949
if self.logging_mode == LoggingMode.GLOBAL_REDUCE:
@@ -964,20 +974,29 @@ async def init(
964974
async def _init_global(self, run_name: str | None):
965975
import wandb
966976

967-
self.run = wandb.init(name=run_name, **self.backend_kwargs)
977+
self.run = wandb.init(
978+
name=run_name, config=self.run_config, **self.backend_kwargs
979+
)
968980

969981
async def _init_per_rank(self, run_name: str):
970982
import wandb
971983

972-
self.run = wandb.init(name=run_name, **self.backend_kwargs)
984+
self.run = wandb.init(
985+
name=run_name, config=self.run_config, **self.backend_kwargs
986+
)
973987

974988
async def _init_shared_global(self, run_name: str | None):
975989
import wandb
976990

977991
settings = wandb.Settings(
978992
mode="shared", x_primary=True, x_label="controller_primary"
979993
)
980-
self.run = wandb.init(name=run_name, settings=settings, **self.backend_kwargs)
994+
self.run = wandb.init(
995+
name=run_name,
996+
config=self.run_config,
997+
settings=settings,
998+
**self.backend_kwargs,
999+
)
9811000

9821001
async def _init_shared_local(
9831002
self, run_name: str, shared_id: str, process_name: str
@@ -994,7 +1013,11 @@ async def _init_shared_local(
9941013

9951014
settings = wandb.Settings(mode="shared", x_primary=False, x_label=process_name)
9961015
self.run = wandb.init(
997-
name=run_name, id=shared_id, settings=settings, **self.backend_kwargs
1016+
name=run_name,
1017+
id=shared_id,
1018+
config=self.run_config,
1019+
settings=settings,
1020+
**self.backend_kwargs,
9981021
)
9991022

10001023
async def log_batch(

0 commit comments

Comments
 (0)