Draft
Changes from all commits (30 commits)
77488cf  commit  (Oct 8, 2025)
feb4771  commit  (Oct 8, 2025)
41ceaa4  update backend role typehints and enum  (Oct 8, 2025)
8a24e71  update where we check FORGE_DISABLE_METRICS  (Oct 8, 2025)
3f3bc51  remove protected import  (Oct 8, 2025)
d82c354  Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2  (Oct 8, 2025)
4fe2611  protect import  (Oct 8, 2025)
8759bc8  Merge branch 'timestamp_logging_diff1' into timestamp_logging_diff2  (Oct 8, 2025)
fbb4a9e  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 8, 2025)
d81a4ed  record_metric uses dataclass Metric  (Oct 8, 2025)
1e2255d  commit  (Oct 8, 2025)
a94c612  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 8, 2025)
5b477e8  commit  (Oct 9, 2025)
f2b3eed  commit  (Oct 9, 2025)
471b88a  revert  (Oct 9, 2025)
1a02784  Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3  (Oct 9, 2025)
fa4895f  remove unnecessary code  (Oct 9, 2025)
7bb1fe7  better logging  (Oct 9, 2025)
43d5d27  docs/names  (Oct 9, 2025)
c97eb98  Merge branch 'timestamp_logging_diff2_5' into timestamp_logging_diff3  (Oct 9, 2025)
75355a2  commit  (Oct 9, 2025)
70e9c67  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 9, 2025)
12f77c9  Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4  (Oct 9, 2025)
1186aec  update cfg back to true  (Oct 9, 2025)
a02ea75  Merge branch 'main' of https://github.com/meta-pytorch/forge into tim…  (Oct 13, 2025)
aa00898  Merge branch 'timestamp_logging_diff3' into timestamp_logging_diff4  (Oct 13, 2025)
b75aa31  tests pass  (Oct 13, 2025)
192d32e  Merge branch 'main' of https://github.com/meta-pytorch/forge into sft…  (Oct 13, 2025)
57877da  main works  (Oct 13, 2025)
2bd3b35  docs and naming  (Oct 14, 2025)
9 changes: 5 additions & 4 deletions apps/grpo/qwen3_1_7b.yaml
@@ -16,11 +16,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas
 # Observability configuration
 metric_logging:
   wandb:
-    project: "grpo-training"
-    group: "grpo_exp_${oc.env:USER}"
-    reduce_across_ranks: True
+    project: grpo-training
+    group: grpo_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+    per_rank_share_run: False
   console:
-    reduce_across_ranks: True
+    logging_mode: global_reduce

 # Dataset configuration
 dataset:
9 changes: 5 additions & 4 deletions apps/grpo/qwen3_32b.yaml
@@ -19,11 +19,12 @@ rollout_threads: 1 # Recommended to set equal to policy.num_replicas
 # Observability configuration
 metric_logging:
   wandb:
-    project: "grpo-training"
-    group: "grpo_exp_${oc.env:USER}"
-    reduce_across_ranks: True
+    project: grpo-training
+    group: grpo_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+    per_rank_share_run: False
   console:
-    reduce_across_ranks: True
+    logging_mode: global_reduce

 # Dataset configuration
 dataset:
9 changes: 5 additions & 4 deletions apps/grpo/qwen3_8b.yaml
@@ -12,11 +12,12 @@ off_by_n: 1 # Off by one by default
 # Observability configuration
 metric_logging:
   wandb:
-    project: "grpo-training"
-    group: "grpo_exp_${oc.env:USER}"
-    reduce_across_ranks: True
+    project: grpo-training
+    group: grpo_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+    per_rank_share_run: False
   console:
-    reduce_across_ranks: True
+    logging_mode: global_reduce

 # Dataset configuration
 dataset:
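The three GRPO configs above make the same change: the wandb and console backends drop reduce_across_ranks in favor of a logging_mode key whose allowed values are listed in the inline comment, plus a new per_rank_share_run flag for wandb. Below is a minimal sketch of loading such a block with OmegaConf (already used by apps/sft_v2/main.py) and checking the mode against those documented values; the check and the ALLOWED_LOGGING_MODES set are illustrative, not forge's actual config validation.

from omegaconf import OmegaConf

# Allowed values taken from the inline comments in the YAML above; this check is
# an illustrative sketch, not forge's own validation logic.
ALLOWED_LOGGING_MODES = {"global_reduce", "per_rank_reduce", "per_rank_no_reduce"}

yaml_snippet = """
metric_logging:
  wandb:
    project: grpo-training
    group: grpo_exp_${oc.env:USER}
    logging_mode: global_reduce
    per_rank_share_run: False
  console:
    logging_mode: global_reduce
"""

cfg = OmegaConf.create(yaml_snippet)
for backend, backend_cfg in cfg.metric_logging.items():
    mode = backend_cfg.get("logging_mode", "global_reduce")
    if mode not in ALLOWED_LOGGING_MODES:
        raise ValueError(f"{backend}: unknown logging_mode {mode!r}")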
7 changes: 7 additions & 0 deletions apps/sft_v2/llama3_8b.yaml
@@ -55,6 +55,13 @@ activation_checkpoint:
   mode: selective
   selective_ac_option: op

+metric_logging:
+  wandb:
+    project: sft-training
+    group: sft_exp_${oc.env:USER}
+    logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
+    per_rank_share_run: False
+
 # profiling:
 # enable_profiling: false

55 changes: 52 additions & 3 deletions apps/sft_v2/main.py
@@ -16,6 +16,7 @@
 import math
 import os
 import sys
+import warnings
 from functools import partial
 from typing import Any

@@ -28,6 +29,7 @@
 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.observability import get_or_create_metric_logger, record_metric, Reduce

 from monarch.actor import current_rank, current_size, endpoint
 from omegaconf import DictConfig, OmegaConf
@@ -109,9 +111,20 @@ def _init_dist(self):
         os.environ.update(env)
         logger.info("env: {}".format(env))

+    async def setup_metric_logger(self):
+        """Initialization happens in the main process. Here we just retrieve it"""
+        mlogger = await get_or_create_metric_logger()
+        return mlogger
+
+    def record_batch_metrics(self, data_metrics: list):
+        """Record dataset metrics using the observability system."""
+        for metric in data_metrics:
+            record_metric(metric.key, metric.value, metric.reduction)
+
     @endpoint
     async def setup(self):
         self.train_dataloader = self.setup_data()
+        self.mlogger = await self.setup_metric_logger()
         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
@@ -235,6 +248,7 @@ def train_step(self, batch) -> None:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)

+        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
         logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
         # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
         # self.pbar.update(1)
@@ -251,14 +265,27 @@ async def train(self) -> None:

         while self.current_step < self.num_training_steps:
             batch = next(dataloader)

+            # Pop and record metrics from batch before moving to device
+            self.record_batch_metrics(batch.pop("metrics", []))
+            record_metric(
+                "ForgeSFTRecipe/train_step/step", self.current_step, Reduce.MEAN
+            )
+
+            # Move tensors to the appropriate device
             for k, v in batch.items():
                 if isinstance(v, torch.Tensor):
                     batch[k] = v.to("cuda")  # TODO: hardcoded for now

             self.train_step(batch)
             # self.profiler.step()
             self.current_step += 1
+
+            # Flush metrics
+            if self._rank == 0:
+                logger.debug(f"Flushing metrics at step {self.current_step}")
+                await self.mlogger.flush.call_one(global_step=self.current_step)
+
             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
@@ -270,16 +297,35 @@ async def run(cfg: DictConfig) -> None:
     async def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
-        if self.metric_logger:
-            self.metric_logger.close()
+        if hasattr(self, "mlogger") and self.mlogger:
+            await self.mlogger.shutdown.call_one()

     def __repr__(self) -> str:
         return "Trainer"


 async def run(cfg: DictConfig) -> None:
-    logging.info("Spawing recipe...")
+
+    # TODO (allenwang28) Required for metric logging to work. Should be removed when V1 becomes default
+    MONARCH_HOSTMESH_V1 = os.getenv("MONARCH_HOSTMESH_V1")
+    if MONARCH_HOSTMESH_V1 != "1":
+        warnings.warn(
+            f"MONARCH_HOSTMESH_V1 is set to {MONARCH_HOSTMESH_V1}. Setting it to '1' for SFT v2 to work properly.",
+            UserWarning,
+            stacklevel=2,
+        )
+        os.environ["MONARCH_HOSTMESH_V1"] = "1"
+
+    logging.info("Spawning recipe...")
     process_cfg = cfg.pop("processes")
+
+    # Initialize metric logger in main process
+    metric_logging_cfg = cfg.get(
+        "metric_logging", {"console": {"logging_mode": "global_reduce"}}
+    )
+    mlogger = await get_or_create_metric_logger(process_name="Controller")
+    await mlogger.init_backends.call_one(metric_logging_cfg)
+
     recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)

     logging.info("Created recipe, running setup.")
@@ -290,6 +336,9 @@ async def run(cfg: DictConfig) -> None:

     logging.info("Done training. Clean up")
     await recipe.cleanup.call()
+
+    # Shutdown metric logger
+    await mlogger.shutdown.call_one()
     await recipe.mesh.stop()
     logging.info("All done!")

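Taken together, the main.py changes wire up a controller-side metric logger (get_or_create_metric_logger plus init_backends), per-step record_metric calls, a rank-0 flush each training step, and a shutdown at the end of run(). Below is a condensed sketch of that call sequence using only the calls visible in the diff; it assumes the forge/Monarch runtime that apps/sft_v2/main.py runs under and a fabricated loss value, so it is a usage outline rather than a standalone script.

from forge.observability import get_or_create_metric_logger, record_metric, Reduce


async def metric_logging_lifecycle() -> None:
    # Controller process: create (or retrieve) the metric logger and initialize
    # backends from a metric_logging config; the console fallback mirrors the
    # default used in run() above.
    mlogger = await get_or_create_metric_logger(process_name="Controller")
    await mlogger.init_backends.call_one({"console": {"logging_mode": "global_reduce"}})

    for step in range(1, 4):
        # Training code records metrics anywhere; the loss value here is fabricated.
        record_metric("ForgeSFTRecipe/train_step/loss", 0.1 * step, Reduce.MEAN)
        # Rank 0 flushes the accumulated metrics for this step, as train() does.
        await mlogger.flush.call_one(global_step=step)

    # Tear down backends once training is complete, as run() and cleanup() do.
    await mlogger.shutdown.call_one()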
8 changes: 7 additions & 1 deletion src/forge/data/__init__.py
@@ -4,6 +4,12 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 from .collate import collate_packed
+from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
 from .utils import CROSS_ENTROPY_IGNORE_IDX

-__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"]
+__all__ = [
+    "collate_packed",
+    "CROSS_ENTROPY_IGNORE_IDX",
+    "MetricTransform",
+    "DefaultDatasetMetricTransform",
+]
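With this change, forge.data re-exports the metric transform types alongside the existing helpers. The constructor signatures of the transforms are not shown in the diff, so only the imports are illustrated below.

# Public names re-exported by forge.data after this change (per the new __all__).
from forge.data import (
    CROSS_ENTROPY_IGNORE_IDX,
    DefaultDatasetMetricTransform,
    MetricTransform,
    collate_packed,
)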
39 changes: 0 additions & 39 deletions src/forge/data/dataset_metrics/__init__.py

This file was deleted.
