Skip to content

Commit 6ec9733

Browse files
author
Felipe Mello
committed
nits
1 parent 5d68cf6 commit 6ec9733

File tree

6 files changed

+14
-32
lines changed

6 files changed

+14
-32
lines changed

apps/sft/main.py

Lines changed: 5 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616
import math
1717
import os
1818
import sys
19-
import warnings
2019
from functools import partial
2120
from typing import Any
2221

@@ -117,14 +116,16 @@ async def setup_metric_logger(self):
117116
return mlogger
118117

119118
def record_batch_metrics(self, data_metrics: list):
120-
"""Record dataset metrics using the observability system."""
119+
"""Since the dataloader creates new processes, we dont call `record_metric` in the dataset.
120+
Instead, pop the metrics from the batch and record them here."""
121121
for metric in data_metrics:
122122
record_metric(metric.key, metric.value, metric.reduction)
123123

124124
@endpoint
125125
async def setup(self):
126126
self.train_dataloader = self.setup_data()
127127
self.mlogger = await self.setup_metric_logger()
128+
128129
# self.train_dataloader = self.setup_data(
129130
# self.train_config.train_dataset_config,
130131
# self.train_config.train_dataloader_config,
@@ -268,9 +269,7 @@ async def train(self) -> None:
268269

269270
# Pop and record metrics from batch before moving to device
270271
self.record_batch_metrics(batch.pop("metrics", []))
271-
record_metric(
272-
"ForgeSFTRecipe/train_step/step", self.current_step, Reduce.MEAN
273-
)
272+
record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN)
274273

275274
# Move tensors to the appropriate device
276275
for k, v in batch.items():
@@ -306,23 +305,11 @@ def __repr__(self) -> str:
306305

307306
async def run(cfg: DictConfig) -> None:
308307

309-
# TODO (allenwang28) Required for metric logging to work. Should be removed when V1 becomes default
310-
MONARCH_HOSTMESH_V1 = os.getenv("MONARCH_HOSTMESH_V1")
311-
if MONARCH_HOSTMESH_V1 != "1":
312-
warnings.warn(
313-
"MONARCH_HOSTMESH_V1 is set to {MONARCH_HOSTMESH_V1}. Setting it to '1' for SFT v2 to work properly. ",
314-
UserWarning,
315-
stacklevel=2,
316-
)
317-
os.environ["MONARCH_HOSTMESH_V1"] = "1"
318-
319308
logging.info("Spawning recipe...")
320309
process_cfg = cfg.pop("processes")
321310

322311
# Initialize metric logger in main process
323-
metric_logging_cfg = cfg.get(
324-
"metric_logging", {"console": {"logging_mode": "global_reduce"}}
325-
)
312+
metric_logging_cfg = cfg.get("metric_logging", {})
326313
mlogger = await get_or_create_metric_logger(process_name="Controller")
327314
await mlogger.init_backends.call_one(metric_logging_cfg)
328315

@@ -337,8 +324,6 @@ async def run(cfg: DictConfig) -> None:
337324
logging.info("Done training. Clean up")
338325
await recipe.cleanup.call()
339326

340-
# Shutdown metric logger
341-
await mlogger.shutdown.call_one()
342327
await recipe.mesh.stop()
343328
logging.info("All done!")
344329

src/forge/data/datasets/hf_dataset.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,8 +232,9 @@ def __iter__(self) -> Iterator[dict[str, Any]]:
232232
# .map is applied lazily and the advantage would be to leverage caching.
233233
sample = self._apply_transforms(sample)
234234

235-
# Track the number of epochs completed for each dataset.
236-
# This is especially useful when interleaving multiple datasets.
235+
# Track the number of epochs completed for each dataset. This is
236+
# especially useful when interleaving multiple datasets, but
237+
# also necessary to track dataset-level metrics.
237238
if "metrics" not in sample:
238239
sample["metrics"] = []
239240

src/forge/observability/metric_actors.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
MetricCollector,
2929
reduce_metrics_states,
3030
)
31-
from forge.observability.utils import detect_actor_name_from_call_stack
3231

3332

3433
logger = logging.getLogger(__name__)
@@ -84,9 +83,6 @@ async def get_or_create_metric_logger(
8483
await mlogger.shutdown.call_one()
8584
"""
8685

87-
if process_name is None:
88-
process_name = detect_actor_name_from_call_stack()
89-
9086
# Get or create the singleton global logger
9187
global _global_logger
9288

src/forge/observability/metrics.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,11 +128,11 @@ def reduce_metrics_states(states: list[dict[str, dict[str, Any]]]) -> list[Metri
128128
states is more precise than merging locally reduced metrics.
129129
130130
Args:
131-
states (list[dict[str, dict[str, Any]]]): list of state of one or more metrics,
131+
states (list[dict[str, dict[str, Any]]]): List of state of one or more metrics,
132132
normally retrieved using `forge.observability.metrics.MetricAccumulator.get_state()`.
133133
134134
Returns:
135-
list[Metric]: list of reduced metrics
135+
list[Metric]: List of reduced metrics
136136
137137
Example:
138138
states = [

tests/unit_tests/datasets/test_hf.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,7 +113,7 @@ def test_default_dataset_name(self, small_dataset_file):
113113
split="train",
114114
# dataset_name not provided - should auto-generate
115115
seed=SEED,
116-
metric_transform=None, # Now using new observability system
116+
metric_transform=DefaultDatasetMetricTransform(),
117117
num_shards_per_rank=4,
118118
)
119119

@@ -131,7 +131,7 @@ def test_default_dataset_name(self, small_dataset_file):
131131
dataset_name="my_dataset",
132132
weight=custom_weight,
133133
seed=SEED,
134-
metric_transform=None, # Now using new observability system
134+
metric_transform=DefaultDatasetMetricTransform(),
135135
num_shards_per_rank=4,
136136
)
137137

@@ -317,7 +317,7 @@ def create_loader():
317317
dataset_name="epoch_test",
318318
seed=SEED,
319319
shuffle_buffer_size=0, # No shuffle for determinism
320-
metric_transform=None, # Now using new observability system
320+
metric_transform=DefaultDatasetMetricTransform(),
321321
num_shards_per_rank=2,
322322
)
323323
loader = StatefulDataLoader(

tests/unit_tests/datasets/test_interleaved.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -308,7 +308,7 @@ def test_metrics_aggregation(
308308
if "metrics" in sample:
309309
collected_metrics.extend(sample["metrics"])
310310

311-
# Count metrics by dataset name (using new metric key)
311+
# Count metrics by dataset name
312312
ds1_samples_processed = sum(
313313
1
314314
for m in collected_metrics

0 commit comments

Comments (0)