 from forge.data.datasets.packed import PackedDataset, TextPacker
 from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
 from forge.data.tokenizer import HuggingFaceModelTokenizer
+from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX

 from omegaconf import DictConfig, OmegaConf
 from torch import nn
+
 from torchdata.stateful_dataloader import StatefulDataLoader
 from torchtitan.components.loss import LossFunction
 from torchtitan.components.lr_scheduler import LRSchedulersContainer
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 from tqdm import tqdm

+
 # stubs for now
 Checkpointer = Any
 Dataloader = Any
@@ -63,7 +66,16 @@ def __init__(self, job_config: ForgeJobConfig):
         self.metric_logger = None  # TODO: fix this

     def setup(self):
-        self.train_dataloader = self.setup_data()
+        self.train_dataloader = self.setup_data(
+            self.job_config.dataset,
+            batch_size=self.job_config.training.local_batch_size,
+        )
+
+        self.val_dataloader = self.setup_data(
+            self.job_config.dataset_val,
+            batch_size=self.job_config.validation.local_batch_size,
+        )
+
         # self.train_dataloader = self.setup_data(
         #     self.train_config.train_dataset_config,
         #     self.train_config.train_dataloader_config,
@@ -79,7 +91,7 @@ def setup(self):
         # self.profiler = self.setup_profiler(self.train_config.profiler_config)
         # self.logger = self.setup_logger(self.train_config.logger_config)

-    def setup_data(self):
+    def setup_data(self, dataset_config, batch_size):
         tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
                 self.job_config.model.hf_assets_path, "tokenizer.json"
@@ -95,8 +107,8 @@ def setup_data(self):
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
             message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            path=dataset_config.path,
+            split=dataset_config.split,
         )
         packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
@@ -106,7 +118,7 @@ def setup_data(self):
         )
         dataloader = StatefulDataLoader(
             dataset=dataset,
-            batch_size=self.job_config.training.local_batch_size,
+            batch_size=batch_size,
             collate_fn=partial(
                 collate_packed, mask_fn=packer.create_block_mask, device=self.device
             ),
@@ -119,7 +131,10 @@ def setup_data(self):
         return dataloader

     def forward_backward(
-        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+        self,
+        input_dict: dict[str, torch.Tensor],
+        labels: torch.Tensor,
+        do_backward: bool = True,
     ) -> torch.Tensor:
         model_parts = self.model_parts
         parallel_dims = self.parallel_dims
@@ -145,14 +160,16 @@ def forward_backward(
                 targets, losses = (
                     (labels, []) if self.pp_has_last_stage else (None, None)
                 )
+                if do_backward:
+                    pp_schedule_fn = self.pp_schedule.step
+                else:
+                    pp_schedule_fn = self.pp_schedule.eval
                 if self.pp_has_first_stage:
-                    self.pp_schedule.step(
+                    pp_schedule_fn(
                         inputs, target=targets, losses=losses, input_batch=inputs
                     )
                 else:
-                    self.pp_schedule.step(
-                        target=targets, losses=losses, input_batch=inputs
-                    )
+                    pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)

             # accumulate losses across pipeline microbatches
             # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -170,7 +187,8 @@ def forward_backward(
                     loss = self.loss_fn(pred, labels)
                 # need to free before bwd to avoid peaking memory
                 del pred
-                loss.backward()
+                if do_backward:
+                    loss.backward()

         return loss

@@ -214,6 +232,52 @@ def train(self) -> None:
                 last_step=self.current_step == self.num_training_steps,
             )

+            if (
+                self.job_config.validation.freq > 0
+                and self.job_config.validation.steps > 0
+                and self.current_step % self.job_config.validation.freq == 0
+            ):
+                self.validate(self.job_config.validation.steps)
+
+    def validate(self, max_steps: int) -> None:
+        for m in self.model_parts:
+            m.eval()
+        total_val_loss = torch.tensor(0.0, device=self.device)
+        total_val_tokens = torch.tensor(0.0, device=self.device)
+        with torch.no_grad():
+            val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
+            for batch_idx, batch in enumerate(val_pbar):
+                if batch_idx >= max_steps:
+                    break
+                batch_to_device(batch, self.device)
+                current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
+                # Compute loss
+                labels = batch.pop("labels")
+                loss = self.forward_backward(batch, labels, do_backward=False)
+                val_loss = loss * current_num_tokens
+                total_val_loss += val_loss
+                total_val_tokens += current_num_tokens
+                # Update progress bar description with current average loss
+                avg_loss_so_far = (
+                    (total_val_loss / total_val_tokens).item()
+                    if total_val_tokens > 0
+                    else float("inf")
+                )
+                val_pbar.set_description(
+                    f"Running validation Loss: {avg_loss_so_far:.4f}"
+                )
+        # Aggregate validation metrics across all ranks
+        torch.distributed.all_reduce(total_val_loss)
+        torch.distributed.all_reduce(total_val_tokens)
+        avg_val_loss = (
+            (total_val_loss / total_val_tokens).item()
+            if total_val_tokens > 0
+            else float("inf")
+        )
+        for m in self.model_parts:
+            m.train()
+        print(f"\nValidation loss: {avg_val_loss}")
+
     def cleanup(self) -> None:
         if self.checkpointer:
             self.checkpointer.close()
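For reference (not part of the diff): the new code paths read job_config.dataset, job_config.dataset_val, job_config.training.local_batch_size, and job_config.validation.local_batch_size / freq / steps. Below is a minimal sketch of a config carrying those fields, assuming an OmegaConf-style mapping; only the field names are taken from the attribute accesses above, the values are purely illustrative and the real ForgeJobConfig schema may differ.

from omegaconf import OmegaConf

# Hypothetical values; field names mirror the accesses in setup(), train(), and validate().
cfg = OmegaConf.create(
    {
        "dataset": {"path": "yahma/alpaca-cleaned", "split": "train"},
        "dataset_val": {"path": "yahma/alpaca-cleaned", "split": "test"},
        "training": {"local_batch_size": 8},
        "validation": {"local_batch_size": 8, "freq": 100, "steps": 50},
    }
)

# Mirrors the gating check in train(): validation runs every `freq` steps, for `steps` batches,
# and is disabled entirely when either value is <= 0.
assert cfg.validation.freq > 0 and cfg.validation.steps > 0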