1616import math
1717import os
1818import sys
19- import warnings
2019from functools import partial
2120from typing import Any
2221
@@ -117,14 +116,16 @@ async def setup_metric_logger(self):
117116 return mlogger
118117
119118 def record_batch_metrics (self , data_metrics : list ):
120- """Record dataset metrics using the observability system."""
119+ """Since the dataloader creates new processes, we don't call `record_metric` in the dataset.
120+ Instead, pop the metrics from the batch and record them here."""
121121 for metric in data_metrics :
122122 record_metric (metric .key , metric .value , metric .reduction )
123123
124124 @endpoint
125125 async def setup (self ):
126126 self .train_dataloader = self .setup_data ()
127127 self .mlogger = await self .setup_metric_logger ()
128+
128129 # self.train_dataloader = self.setup_data(
129130 # self.train_config.train_dataset_config,
130131 # self.train_config.train_dataloader_config,
@@ -268,9 +269,7 @@ async def train(self) -> None:
268269
269270 # Pop and record metrics from batch before moving to device
270271 self .record_batch_metrics (batch .pop ("metrics" , []))
271- record_metric (
272- "ForgeSFTRecipe/train_step/step" , self .current_step , Reduce .MEAN
273- )
272+ record_metric ("ForgeSFTRecipe/train/step" , self .current_step , Reduce .MEAN )
274273
275274 # Move tensors to the appropriate device
276275 for k , v in batch .items ():
@@ -306,23 +305,11 @@ def __repr__(self) -> str:
306305
307306async def run (cfg : DictConfig ) -> None :
308307
309- # TODO (allenwang28) Required for metric logging to work. Should be removed when V1 becomes default
310- MONARCH_HOSTMESH_V1 = os .getenv ("MONARCH_HOSTMESH_V1" )
311- if MONARCH_HOSTMESH_V1 != "1" :
312- warnings .warn (
313- "MONARCH_HOSTMESH_V1 is set to {MONARCH_HOSTMESH_V1}. Setting it to '1' for SFT v2 to work properly. " ,
314- UserWarning ,
315- stacklevel = 2 ,
316- )
317- os .environ ["MONARCH_HOSTMESH_V1" ] = "1"
318-
319308 logging .info ("Spawning recipe..." )
320309 process_cfg = cfg .pop ("processes" )
321310
322311 # Initialize metric logger in main process
323- metric_logging_cfg = cfg .get (
324- "metric_logging" , {"console" : {"logging_mode" : "global_reduce" }}
325- )
312+ metric_logging_cfg = cfg .get ("metric_logging" , {})
326313 mlogger = await get_or_create_metric_logger (process_name = "Controller" )
327314 await mlogger .init_backends .call_one (metric_logging_cfg )
328315
@@ -337,8 +324,6 @@ async def run(cfg: DictConfig) -> None:
337324 logging .info ("Done training. Clean up" )
338325 await recipe .cleanup .call ()
339326
340- # Shutdown metric logger
341- await mlogger .shutdown .call_one ()
342327 await recipe .mesh .stop ()
343328 logging .info ("All done!" )
344329
0 commit comments