2727from forge .data .datasets .packed import PackedDataset , TextPacker
2828from forge .data .datasets .sft_dataset import AlpacaToMessages , sft_iterable_dataset
2929from forge .data .tokenizer import HuggingFaceModelTokenizer
30+ from forge .observability import get_or_create_metric_logger , record_metric , Reduce
3031from forge .util .config import parse
3132
3233from monarch .actor import current_rank , current_size , endpoint
@@ -77,7 +78,6 @@ def __init__(self, config: DictConfig):
7778
7879 self .current_step = 0
7980 self .num_training_steps = job_config .training .steps
80- self .metric_logger = None # TODO: fix this
8181 self .gradient_accumulation_steps = 1 # Example value, adjust as needed
8282 self ._rank = current_rank ().rank
8383 self ._size = math .prod (current_size ().values ())
@@ -109,9 +109,22 @@ def _init_dist(self):
109109 os .environ .update (env )
110110 logger .info ("env: {}" .format (env ))
111111
async def setup_metric_logger(self):
    """Return the metric logger handle used by this recipe.

    The logging backends are initialized in the main process; this method
    only retrieves the logger through ``get_or_create_metric_logger``.
    """
    return await get_or_create_metric_logger()
116+
def record_batch_metrics(self, data_metrics: list):
    """Record the metrics that were carried along with a batch.

    Since the dataloader creates new processes, `record_metric` is not
    called in the dataset. Instead, the metrics are popped from the batch
    and recorded here, one `record_metric` call per entry.
    """
    for entry in data_metrics:
        key, value, reduction = entry.key, entry.value, entry.reduction
        record_metric(key, value, reduction)
122+
112123 @endpoint
113124 async def setup (self ):
114125 self .train_dataloader = self .setup_data ()
126+ self .mlogger = await self .setup_metric_logger ()
127+
115128 # self.train_dataloader = self.setup_data(
116129 # self.train_config.train_dataset_config,
117130 # self.train_config.train_dataloader_config,
@@ -234,7 +247,9 @@ def train_step(self, batch) -> None:
234247 # ) as grad_acc:
235248 labels = batch .pop ("labels" )
236249 loss = self .forward_backward (batch , labels )
250+ loss = loss .item ()
237251
252+ record_metric ("ForgeSFTRecipe/train_step/loss" , loss , Reduce .MEAN )
238253 logger .info (f"{ self .current_step } / { self .num_training_steps } |Loss: { loss } " )
239254 # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
240255 # self.pbar.update(1)
@@ -251,14 +266,25 @@ async def train(self) -> None:
251266
252267 while self .current_step < self .num_training_steps :
253268 batch = next (dataloader )
269+
270+ # Pop and record metrics from batch before moving to device
271+ self .record_batch_metrics (batch .pop ("metrics" , []))
272+ record_metric ("ForgeSFTRecipe/train/step" , self .current_step , Reduce .MEAN )
273+
254274 # Move tensors to the appropriate device
255275 for k , v in batch .items ():
256276 if isinstance (v , torch .Tensor ):
257277 batch [k ] = v .to ("cuda" ) # TODO: hardcoded for now
278+
258279 self .train_step (batch )
259280 # self.profiler.step()
260281 self .current_step += 1
261282
283+ # Flush metrics
284+ if self ._rank == 0 :
285+ logger .debug (f"Flushing metrics at step { self .current_step } " )
286+ await self .mlogger .flush .call_one (global_step = self .current_step )
287+
262288 self .checkpointer .save (
263289 curr_step = self .current_step ,
264290 last_step = self .current_step == self .num_training_steps ,
@@ -270,16 +296,23 @@ async def train(self) -> None:
async def cleanup(self) -> None:
    """Close the checkpointer and shut down the metric logger.

    Both attributes are looked up with ``getattr`` so that cleanup is a
    no-op for any resource that was never attached to this instance,
    instead of raising ``AttributeError`` and skipping the remaining
    teardown.
    """
    checkpointer = getattr(self, "checkpointer", None)
    if checkpointer:
        checkpointer.close()

    mlogger = getattr(self, "mlogger", None)
    if mlogger:
        # The logger is shut down through its `shutdown` endpoint.
        await mlogger.shutdown.call_one()
275301
276302 def __repr__ (self ) -> str :
277303 return "Trainer"
278304
279305
280306async def run (cfg : DictConfig ) -> None :
281- logging .info ("Spawing recipe..." )
307+
308+ logging .info ("Spawning recipe..." )
282309 process_cfg = cfg .pop ("processes" )
310+
311+ # Initialize metric logger in main process
312+ metric_logging_cfg = cfg .get ("metric_logging" , {})
313+ mlogger = await get_or_create_metric_logger (process_name = "Controller" )
314+ await mlogger .init_backends .call_one (metric_logging_cfg )
315+
283316 recipe = await ForgeSFTRecipe .options (** process_cfg ).as_actor (cfg )
284317
285318 logging .info ("Created recipe, running setup." )
@@ -290,6 +323,7 @@ async def run(cfg: DictConfig) -> None:
290323
291324 logging .info ("Done training. Clean up" )
292325 await recipe .cleanup .call ()
326+
293327 await recipe .mesh .stop ()
294328 logging .info ("All done!" )
295329
0 commit comments