 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
-from forge.observability import get_or_create_metric_logger, record_metric, Reduce
 from forge.util.config import parse
 
 from monarch.actor import current_rank, current_size, endpoint
@@ -78,6 +77,7 @@ def __init__(self, config: DictConfig):
 
         self.current_step = 0
         self.num_training_steps = job_config.training.steps
+        self.metric_logger = None  # TODO: fix this
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
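
For context on the `_size` computation kept above: `current_size()` appears to map mesh dimension names to extents, so the flat world size is the product of those extents. A toy illustration with an assumed mapping (the dimension names here are made up, not taken from Monarch):

    import math

    # Assumed shape of current_size()'s return value: mesh dimension
    # names mapped to extents, e.g. two hosts with eight GPUs each.
    mesh_extents = {"hosts": 2, "gpus": 8}
    world_size = math.prod(mesh_extents.values())  # 16 ranks total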
@@ -109,22 +109,9 @@ def _init_dist(self):
         os.environ.update(env)
         logger.info("env: {}".format(env))
 
-    async def setup_metric_logger(self):
-        """Initialization happens in the main process. Here we just retrieve it"""
-        mlogger = await get_or_create_metric_logger()
-        return mlogger
-
-    def record_batch_metrics(self, data_metrics: list):
-        """Since the dataloader creates new processes, we dont call `record_metric` in the dataset.
-        Instead, pop the metrics from the batch and record them here."""
-        for metric in data_metrics:
-            record_metric(metric.key, metric.value, metric.reduction)
-
     @endpoint
     async def setup(self):
         self.train_dataloader = self.setup_data()
-        self.mlogger = await self.setup_metric_logger()
-
         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
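
For reference, the pattern deleted above can be sketched in isolation. This is a hypothetical stand-in, not the Forge API: the `Metric` shape mirrors the deleted docstring and loop, and `record_fn` stands in for `record_metric`. The point of the pattern is that dataloader workers run in separate processes, so per-sample metrics travel inside the batch and are recorded in the trainer process:

    from dataclasses import dataclass
    from typing import Any, Callable

    @dataclass
    class Metric:
        key: str
        value: Any
        reduction: str  # e.g. "mean"

    def record_batch_metrics(batch: dict, record_fn: Callable[[str, Any, str], None]) -> None:
        # Workers can't call the logger directly, so metrics ride along
        # in the batch and are popped off before the batch hits the model.
        for metric in batch.pop("metrics", []):
            record_fn(metric.key, metric.value, metric.reduction)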
@@ -247,9 +234,7 @@ def train_step(self, batch) -> None:
         # ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
-        loss = loss.item()
 
-        record_metric("ForgeSFTRecipe/train_step/loss", loss, Reduce.MEAN)
         logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
         # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
         # self.pbar.update(1)
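
A side note on the deleted `loss = loss.item()`: `.item()` turns a one-element tensor into a plain Python float, synchronizing with the device and detaching the value from autograd, which is what makes it safe to log or accumulate across steps. A standalone sketch:

    import torch

    device = "cuda" if torch.cuda.is_available() else "cpu"
    loss = torch.tensor(0.42, device=device)
    # .item() forces a host sync and returns a plain Python float,
    # which is what metric sinks and log formatting expect.
    print(f"loss: {loss.item():.4f}")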
@@ -266,25 +251,14 @@ async def train(self) -> None:
 
         while self.current_step < self.num_training_steps:
             batch = next(dataloader)
-
-            # Pop and record metrics from batch before moving to device
-            self.record_batch_metrics(batch.pop("metrics", []))
-            record_metric("ForgeSFTRecipe/train/step", self.current_step, Reduce.MEAN)
-
             # Move tensors to the appropriate device
             for k, v in batch.items():
                 if isinstance(v, torch.Tensor):
                     batch[k] = v.to("cuda")  # TODO: hardcoded for now
-
             self.train_step(batch)
             # self.profiler.step()
             self.current_step += 1
 
-            # Flush metrics
-            if self._rank == 0:
-                logger.debug(f"Flushing metrics at step {self.current_step}")
-                await self.mlogger.flush.call_one(global_step=self.current_step)
-
             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
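
The device-transfer loop kept in this hunk is the usual "move only tensors" pattern. A generalized sketch that parameterizes the hardcoded "cuda" (the helper name and the `non_blocking` choice are illustrative, not part of this recipe):

    import torch

    def move_to_device(batch: dict, device: torch.device) -> dict:
        # Only tensors are moved; ids, strings, and other metadata pass through.
        # non_blocking=True lets the copy overlap compute when the source
        # tensors live in pinned memory.
        return {
            k: v.to(device, non_blocking=True) if isinstance(v, torch.Tensor) else v
            for k, v in batch.items()
        }

    # usage: batch = move_to_device(batch, torch.device("cuda"))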
@@ -296,23 +270,16 @@ async def train(self) -> None:
     async def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
-        if getattr(self, "mlogger", None):
-            await self.mlogger.shutdown.call_one()
+        if self.metric_logger:
+            self.metric_logger.close()
 
     def __repr__(self) -> str:
         return "Trainer"
 
 
 async def run(cfg: DictConfig) -> None:
-
-    logging.info("Spawning recipe...")
+    logging.info("Spawning recipe...")
     process_cfg = cfg.pop("processes")
-
-    # Initialize metric logger in main process
-    metric_logging_cfg = cfg.get("metric_logging", {})
-    mlogger = await get_or_create_metric_logger(process_name="Controller")
-    await mlogger.init_backends.call_one(metric_logging_cfg)
-
     recipe = await ForgeSFTRecipe.options(**process_cfg).as_actor(cfg)
 
     logging.info("Created recipe, running setup.")
@@ -323,7 +290,6 @@ async def run(cfg: DictConfig) -> None:
 
     logging.info("Done training. Clean up")
     await recipe.cleanup.call()
-
    await recipe.mesh.stop()
     logging.info("All done!")
 
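As a usage note, `run` is an ordinary coroutine, so a minimal driver can be sketched as below. The hand-built config is hypothetical; the real entry point goes through `forge.util.config.parse`, and only the `processes` and `training.steps` keys actually appear in this diff:

    import asyncio
    from omegaconf import OmegaConf

    if __name__ == "__main__":
        # Hypothetical config stub; real runs supply a full recipe config,
        # and "procs" is a guessed key for the process mesh size.
        cfg = OmegaConf.create({"processes": {"procs": 2}, "training": {"steps": 10}})
        asyncio.run(run(cfg))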