Skip to content

Commit c2049fc

Browse files
authored
Revert "Metric Logging updates 6/N - Enable SFT metrics / delete old file" (meta-pytorch#489)
1 parent: 5a533b1 · commit: c2049fc

22 files changed

+1870 -342 lines changed

apps/sft/llama3_8b.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ parallelism:
4646
checkpoint:
4747
enable: true
4848
folder: ./checkpoint # The folder to save checkpoints to.
49-
initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
49+
initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
5050
initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
5151
last_save_in_hf: true
5252
interval: 500
@@ -56,12 +56,6 @@ activation_checkpoint:
5656
mode: selective
5757
selective_ac_option: op
5858

59-
metric_logging:
60-
wandb:
61-
project: sft-training
62-
group: sft_exp_${oc.env:USER}
63-
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
64-
6559
# profiling:
6660
# enable_profiling: false
6761

apps/sft/main.py

Lines changed: 4 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from forge.data.datasets.packed import PackedDataset, TextPacker
2828
from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
2929
from forge.data.tokenizer import HuggingFaceModelTokenizer
30-
from forge.observability import get_or_create_metric_logger, record_metric, Reduce
3130
from forge.util.config import parse
3231

3332
from monarch.actor import current_rank, current_size, endpoint
@@ -78,6 +77,7 @@ def __init__(self, config: DictConfig):
7877

7978
self.current_step = 0
8079
self.num_training_steps = job_config.training.steps
80+
self.metric_logger = None # TODO: fix this
8181
self.gradient_accumulation_steps = 1 # Example value, adjust as needed
8282
self._rank = current_rank().rank
8383
self._size = math.prod(current_size().values())
@@ -109,22 +109,9 @@ def _init_dist(self):
109109
os.environ.update(env)
110110
logger.info("env: {}".format(env))
111111

112-
async def setup_metric_logger(self):
113-
"""Initialization happens in the main process. Here we just retrieve it"""
114-
mlogger = await get_or_create_metric_logger()
115-
return mlogger
116-
117-
def record_batch_metrics(self, data_metrics: list):
118-
"""Since the dataloader creates new processes, we dont call `record_metric` in the dataset.
119-
Instead, pop the metrics from the batch and record them here."""
120-
for metric in data_metrics:
121-
record_metric(metric.key, metric.value, metric.reduction)
122-
123112
@endpoint
124113
async def setup(self):
125114
self.train_dataloader = self.setup_data()
126-
self.mlogger = await self.setup_metric_logger()
127-
128115
# self.train_dataloader = self.setup_data(
129116
# self.train_config.train_dataset_config,
130117
# self.train_config.train_dataloader_config,
@@ -247,9 +234,7 @@ def train_step(self, batch) -> None:
247234
# ) as grad_acc:
248235
labels = batch.pop("labels")
249236
loss = self.forward_backward(batch, labels)
250-
loss = loss.item()
251237

252-
record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
253238
logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
254239
# self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
255240
# self.pbar.update(1)
@@ -266,25 +251,14 @@ async def train(self) -> None:
266251

267252
while self.current_step < self.num_training_steps:
268253
batch = next(dataloader)
269-
270-
# Pop and record metrics from batch before moving to device
271-
self.record_batch_metrics(batch.pop("metrics", []))
272-
record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN)
273-
274254
# Move tensors to the appropriate device
275255
for k, v in batch.items():
276256
if isinstance(v, torch.Tensor):
277257
batch[k] = v.to("cuda") # TODO: hardcoded for now
278-
279258
self.train_step(batch)
280259
# self.profiler.step()
281260
self.current_step += 1
282261

283-
# Flush metrics
284-
if self._rank == 0:
285-
logger.debug(f"Flushing metrics at step {self.current_step}")
286-
await self.mlogger.flush.call_one(global_step=self.current_step)
287-
288262
self.checkpointer.save(
289263
curr_step=self.current_step,
290264
last_step=self.current_step == self.num_training_steps,
@@ -296,23 +270,16 @@ async def train(self) -> None:
296270
async def cleanup(self) -> None:
297271
if self.checkpointer:
298272
self.checkpointer.close()
299-
if getattr(self, "mlogger", None):
300-
await self.mlogger.shutdown.call_one()
273+
if self.metric_logger:
274+
self.metric_logger.close()
301275

302276
def __repr__(self) -> str:
303277
return "Trainer"
304278

305279

306280
async def run(cfg: DictConfig) -> None:
307-
308-
logging.info("Spawning recipe...")
281+
logging.info("Spawing recipe...")
309282
process_cfg = cfg.pop("processes")
310-
311-
# Initialize metric logger in main process
312-
metric_logging_cfg = cfg.get("metric_logging", {})
313-
mlogger = await get_or_create_metric_logger(process_name="Controller")
314-
await mlogger.init_backends.call_one(metric_logging_cfg)
315-
316283
recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
317284

318285
logging.info("Created recipe, running setup.")
@@ -323,7 +290,6 @@ async def run(cfg: DictConfig) -> None:
323290

324291
logging.info("Done training. Clean up")
325292
await recipe.cleanup.call()
326-
327293
await recipe.mesh.stop()
328294
logging.info("All done!")
329295

apps/sft/qwen3_8b.yaml

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ parallelism:
4545
checkpoint:
4646
enable: true
4747
folder: ./checkpoint # The folder to save checkpoints to.
48-
initial_load_path: hf://${model_name} # The path to load the initial checkpoint from. Ignored if `folder` exists.
48+
initial_load_path: hf://${model} # The path to load the initial checkpoint from. Ignored if `folder` exists.
4949
initial_load_in_hf: true # If true, interpret initial_load_path as a HuggingFace model repo
5050
last_save_in_hf: true
5151
interval: 500
@@ -55,12 +55,6 @@ activation_checkpoint:
5555
mode: selective
5656
selective_ac_option: op
5757

58-
metric_logging:
59-
wandb:
60-
project: sft-training
61-
group: sft_exp_${oc.env:USER}
62-
logging_mode: global_reduce # global_reduce, per_rank_reduce, per_rank_no_reduce
63-
6458
# profiling:
6559
# enable_profiling: false
6660

src/forge/data/__init__.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,6 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66
from .collate import collate_packed
7-
from .metric_transform import DefaultDatasetMetricTransform, MetricTransform
87
from .utils import CROSS_ENTROPY_IGNORE_IDX
98

10-
__all__ = [
11-
"collate_packed",
12-
"CROSS_ENTROPY_IGNORE_IDX",
13-
"MetricTransform",
14-
"DefaultDatasetMetricTransform",
15-
]
9+
__all__ = ["collate_packed", "CROSS_ENTROPY_IGNORE_IDX"]
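(new file; its path was not captured in this extract. Judging from the hunk below, it is a package `__init__.py` that re-exports names from `.metric_agg_handlers`, `.metric_aggregator`, and `.metric_transform`.)

Lines changed: 39 additions & 0 deletions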
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
from .metric_agg_handlers import (
8+
AggregationHandler,
9+
CategoricalCountAggHandler,
10+
MaxAggHandler,
11+
MeanAggHandler,
12+
MetricState,
13+
MinAggHandler,
14+
StatsAggHandler,
15+
SumAggHandler,
16+
)
17+
from .metric_aggregator import MetricsAggregator
18+
from .metric_transform import (
19+
AggregationType,
20+
DefaultTrainingMetricTransform,
21+
Metric,
22+
MetricTransform,
23+
)
24+
25+
__all__ = [
26+
"AggregationType",
27+
"AggregationHandler",
28+
"CategoricalCountAggHandler",
29+
"DefaultTrainingMetricTransform",
30+
"StatsAggHandler",
31+
"MaxAggHandler",
32+
"MeanAggHandler",
33+
"Metric",
34+
"MetricState",
35+
"MetricsAggregator",
36+
"MetricTransform",
37+
"MinAggHandler",
38+
"SumAggHandler",
39+
]

0 commit comments

Comments (0)