1818from forge .data .datasets .packed import PackedDataset , TextPacker
1919from forge .data .datasets .sft_dataset import AlpacaToMessages , sft_iterable_dataset
2020from forge .data .tokenizer import HuggingFaceModelTokenizer
21+ from forge .data .utils import batch_to_device , CROSS_ENTROPY_IGNORE_IDX
2122from forge .util import get_metric_logger
2223
2324from omegaconf import DictConfig , OmegaConf
2425from torch import nn
26+
2527from torchdata .stateful_dataloader import StatefulDataLoader
2628from torchtitan .components .loss import LossFunction
2729from torchtitan .components .lr_scheduler import LRSchedulersContainer
3133from torchtitan .experiments .forge .job_config import ForgeJobConfig
3234from tqdm import tqdm
3335
36+
3437# stubs for now
3538Checkpointer = Any
3639Dataloader = Any
@@ -64,7 +67,16 @@ def __init__(self, job_config: ForgeJobConfig):
6467 self .metric_logger = get_metric_logger (** job_config .metrics )
6568
6669 def setup (self ):
67- self .train_dataloader = self .setup_data ()
70+ self .train_dataloader = self .setup_data (
71+ self .job_config .dataset ,
72+ batch_size = self .job_config .training .local_batch_size ,
73+ )
74+
75+ self .val_dataloader = self .setup_data (
76+ self .job_config .dataset_val ,
77+ batch_size = self .job_config .validation .local_batch_size ,
78+ )
79+
6880 # self.train_dataloader = self.setup_data(
6981 # self.train_config.train_dataset_config,
7082 # self.train_config.train_dataloader_config,
@@ -80,7 +92,7 @@ def setup(self):
8092 # self.profiler = self.setup_profiler(self.train_config.profiler_config)
8193 # self.logger = self.setup_logger(self.train_config.logger_config)
8294
83- def setup_data (self ):
95+ def setup_data (self , dataset_config , batch_size ):
8496 tokenizer = HuggingFaceModelTokenizer (
8597 tokenizer_json_path = os .path .join (
8698 self .job_config .model .hf_assets_path , "tokenizer.json"
@@ -96,8 +108,8 @@ def setup_data(self):
96108 dataset = sft_iterable_dataset (
97109 model_transform = tokenizer ,
98110 message_transform = AlpacaToMessages (),
99- path = "yahma/alpaca-cleaned" ,
100- split = "train" ,
111+ path = dataset_config . path ,
112+ split = dataset_config . split ,
101113 )
102114 packer = TextPacker (padding_idx = 0 )
103115 dataset = PackedDataset (
@@ -107,7 +119,7 @@ def setup_data(self):
107119 )
108120 dataloader = StatefulDataLoader (
109121 dataset = dataset ,
110- batch_size = self . job_config . training . local_batch_size ,
122+ batch_size = batch_size ,
111123 collate_fn = partial (
112124 collate_packed , mask_fn = packer .create_block_mask , device = self .device
113125 ),
@@ -120,7 +132,10 @@ def setup_data(self):
120132 return dataloader
121133
122134 def forward_backward (
123- self , input_dict : dict [str , torch .Tensor ], labels : torch .Tensor
135+ self ,
136+ input_dict : dict [str , torch .Tensor ],
137+ labels : torch .Tensor ,
138+ do_backward : bool = True ,
124139 ) -> torch .Tensor :
125140 model_parts = self .model_parts
126141 parallel_dims = self .parallel_dims
@@ -146,14 +161,16 @@ def forward_backward(
146161 targets , losses = (
147162 (labels , []) if self .pp_has_last_stage else (None , None )
148163 )
164+ if do_backward :
165+ pp_schedule_fn = self .pp_schedule .step
166+ else :
167+ pp_schedule_fn = self .pp_schedule .eval
149168 if self .pp_has_first_stage :
150- self . pp_schedule . step (
169+ pp_schedule_fn (
151170 inputs , target = targets , losses = losses , input_batch = inputs
152171 )
153172 else :
154- self .pp_schedule .step (
155- target = targets , losses = losses , input_batch = inputs
156- )
173+ pp_schedule_fn (target = targets , losses = losses , input_batch = inputs )
157174
158175 # accumulate losses across pipeline microbatches
159176 # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
@@ -171,7 +188,8 @@ def forward_backward(
171188 loss = self .loss_fn (pred , labels )
172189 # need to free pred before bwd to avoid peaking memory
173190 del pred
174- loss .backward ()
191+ if do_backward :
192+ loss .backward ()
175193
176194 return loss
177195
@@ -216,6 +234,52 @@ def train(self) -> None:
216234 last_step = self .current_step == self .num_training_steps ,
217235 )
218236
237+ if (
238+ self .job_config .validation .freq > 0
239+ and self .job_config .validation .steps > 0
240+ and self .current_step % self .job_config .validation .freq == 0
241+ ):
242+ self .validate (self .job_config .validation .steps )
243+
def validate(self, max_steps: int) -> None:
    """Run a bounded validation pass and print the token-weighted mean loss.

    Every model part is switched to eval mode, at most ``max_steps`` batches
    are drawn from ``self.val_dataloader`` and evaluated under
    ``torch.no_grad()`` via ``forward_backward(..., do_backward=False)``.
    Each batch loss is weighted by the number of label positions that differ
    from ``CROSS_ENTROPY_IGNORE_IDX``; the weighted sum and the token count
    are summed across ranks when a process group is initialised. The model
    parts are returned to train mode even if a validation step raises.

    Args:
        max_steps: Upper bound on the number of validation batches to
            evaluate. Values <= 0 evaluate no batches.

    Returns:
        None. The aggregated loss is printed; it is ``inf`` when no
        non-ignored tokens were seen.
    """
    # Local import so this method stays self-contained.
    from itertools import islice

    for m in self.model_parts:
        m.eval()

    total_val_loss = torch.tensor(0.0, device=self.device)
    total_val_tokens = torch.tensor(0.0, device=self.device)
    step_budget = max(max_steps, 0)

    try:
        # islice stops *before* requesting batch ``step_budget + 1``, so the
        # dataloader is never advanced past what is actually evaluated, and
        # the context manager closes the bar on every exit path.
        with torch.no_grad(), tqdm(
            islice(self.val_dataloader, step_budget),
            total=step_budget,
            desc="Validation",
            leave=False,
        ) as val_pbar:
            for batch in val_pbar:
                # Return value is ignored, as in the original call site.
                # NOTE(review): presumably moves the tensors in place - confirm.
                batch_to_device(batch, self.device)
                current_num_tokens = (
                    batch["labels"] != CROSS_ENTROPY_IGNORE_IDX
                ).sum()
                # Compute loss
                labels = batch.pop("labels")
                loss = self.forward_backward(batch, labels, do_backward=False)
                # NOTE(review): weighting by token count assumes ``loss`` is a
                # per-token mean on every rank - confirm for pipeline stages
                # that do not own the last stage.
                total_val_loss += loss * current_num_tokens
                total_val_tokens += current_num_tokens
                # Update progress bar description with current average loss
                avg_loss_so_far = (
                    (total_val_loss / total_val_tokens).item()
                    if total_val_tokens > 0
                    else float("inf")
                )
                val_pbar.set_description(
                    f"Running validation Loss: {avg_loss_so_far:.4f}"
                )

        # Aggregate validation metrics across all ranks. Skipped when no
        # process group exists so single-process runs do not fail here.
        if torch.distributed.is_available() and torch.distributed.is_initialized():
            torch.distributed.all_reduce(total_val_loss)
            torch.distributed.all_reduce(total_val_tokens)
        avg_val_loss = (
            (total_val_loss / total_val_tokens).item()
            if total_val_tokens > 0
            else float("inf")
        )
    finally:
        # Always hand the model back in training mode, even on failure.
        for m in self.model_parts:
            m.train()

    print(f"\nValidation loss: {avg_val_loss}")
282+
219283 def cleanup (self ) -> None :
220284 if self .checkpointer :
221285 self .checkpointer .close ()
0 commit comments