8484logger = logging .getLogger (__name__ )
8585
8686
def compute_validation_loss(model, val_data_loader, device):
    """Evaluate per-token cross-entropy loss over the validation set.

    Mirrors the training-time loss (manually shifted CE summed over
    non-masked tokens) but runs under ``torch.no_grad()`` with no
    backward pass or gradient accumulation. Per-rank sums are combined
    with a single all-reduce per quantity after the full pass.

    Args:
        model: Model to evaluate; restored to ``train()`` mode on exit.
        val_data_loader: Iterable of batches of micro-batches, or None
            to skip validation entirely.
        device: Device on which to run the forward passes.

    Returns:
        dict: ``val_loss`` and ``val_num_tokens`` when any loss-counted
        tokens were seen; otherwise an empty dict.
    """
    if val_data_loader is None:
        return {}

    local_rank = int(os.environ["LOCAL_RANK"])
    base_logger = logging.getLogger("instructlab.training")

    if local_rank == 0:
        base_logger.info("Computing validation loss...")

    model.eval()
    loss_sum = 0.0
    token_count = 0

    with torch.no_grad():
        for batch in val_data_loader:
            batch_loss_sum = 0.0

            for micro_batch in batch:
                # Assemble forward-pass kwargs; position_ids and
                # attention_mask are only present for some collators.
                inputs = {
                    "input_ids": micro_batch["input_ids"].to(device),
                    "labels": micro_batch["labels"].to(device),
                }
                for optional_key in ("position_ids", "attention_mask"):
                    if optional_key in micro_batch:
                        inputs[optional_key] = micro_batch[optional_key].to(device)

                output = model(**inputs, use_cache=False)

                # Shifted next-token CE, identical to the training loss
                # but left unscaled (no gradient-accumulation divisor).
                flat_logits = output.logits[..., :-1, :].contiguous()
                flat_labels = inputs["labels"][..., 1:].contiguous()
                flat_logits = flat_logits.view(-1, flat_logits.size(-1))
                flat_labels = flat_labels.view(-1)

                ce = torch.nn.CrossEntropyLoss(reduction="none")
                per_token_loss = ce(flat_logits, flat_labels)
                counted = flat_labels != -100
                batch_loss_sum += per_token_loss[counted].sum().item()

            loss_sum += batch_loss_sum
            token_count += batch[0]["batch_num_loss_counted_tokens"]

    # One all-reduce per quantity after the whole pass (SUM is associative).
    reduced_loss = torch.tensor(loss_sum, device=device)
    reduced_tokens = torch.tensor(token_count, device=device, dtype=torch.long)
    dist.all_reduce(reduced_loss, op=dist.ReduceOp.SUM)
    dist.all_reduce(reduced_tokens, op=dist.ReduceOp.SUM)
    loss_sum = reduced_loss.item()
    token_count = int(reduced_tokens.item())

    val_metrics = {}
    if token_count > 0:
        avg_val_loss = loss_sum / token_count
        val_metrics = {
            "val_loss": avg_val_loss,
            "val_num_tokens": token_count,
        }
        if local_rank == 0:
            base_logger.info("Validation loss: %.6f", avg_val_loss)

    model.train()
    return val_metrics
168+
169+
87170def train (
88171 args ,
89172 model : Model ,
90173 accelerator : Accelerator ,
174+ val_data_loader = None ,
175+ validation_frequency = None ,
91176):
92177 model .train ()
93178
@@ -193,6 +278,22 @@ def train(
193278 extra = {"step" : global_step },
194279 )
195280
281+ # Compute validation loss if it's time to validate
282+ if (
283+ val_data_loader is not None
284+ and validation_frequency is not None
285+ and global_step % validation_frequency == 0
286+ ):
287+ torch_device = torch .device ("cuda" , local_rank )
288+ val_metrics = compute_validation_loss (
289+ model , val_data_loader , torch_device
290+ )
291+ if val_metrics and local_rank == 0 :
292+ metric_logger .info (
293+ val_metrics ,
294+ extra = {"step" : global_step },
295+ )
296+
196297 if args .save_samples > 0 and (samples_seen % args .save_samples == 0 ):
197298 base_logger .debug (f"Saving checkpoint at step { global_step } " )
198299 save_checkpoint (
@@ -377,7 +478,10 @@ def main(args):
377478
378479 pad_token_id = tokenizer .pad_token_id if tokenizer .pad_token_id is not None else 0
379480
380- train_loader = get_data_loader (
481+ validation_split = getattr (args , "validation_split" , 0.0 )
482+ validation_frequency = getattr (args , "validation_frequency" , None )
483+
484+ train_loader , val_loader = get_data_loader (
381485 data_path = args .data_path ,
382486 batch_size = batch_size ,
383487 max_tokens_per_gpu = packing_max_batch_len ,
@@ -388,6 +492,7 @@ def main(args):
388492 flash_enabled = flash_enabled ,
389493 pad_token_id = pad_token_id ,
390494 pretraining_config = getattr (args , "pretraining_config" , None ),
495+ validation_split = validation_split ,
391496 )
392497
393498 if args .local_rank == 0 :
@@ -454,6 +559,8 @@ def main(args):
454559 args ,
455560 model = m ,
456561 accelerator = accelerator ,
562+ val_data_loader = val_loader ,
563+ validation_frequency = validation_frequency ,
457564 )
458565
459566 dist .barrier ()
@@ -585,6 +692,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
585692 if train_args .tensorboard_log_dir is not None :
586693 command .append (f"--tensorboard_log_dir={ train_args .tensorboard_log_dir } " )
587694
695+ # Validation parameters
696+ if train_args .validation_split > 0.0 :
697+ command .append (f"--validation_split={ train_args .validation_split } " )
698+ if train_args .validation_frequency is not None :
699+ command .append (f"--validation_frequency={ train_args .validation_frequency } " )
700+
588701 if train_args .pretraining_config is not None :
589702 command .append (f"--block-size={ train_args .pretraining_config .block_size } " )
590703 command .append (
@@ -961,7 +1074,27 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
9611074 default = 1e-8 ,
9621075 help = "Epsilon for numerical stability in AdamW optimizer." ,
9631076 )
1077+ parser .add_argument (
1078+ "--validation_split" ,
1079+ type = float ,
1080+ default = 0.0 ,
1081+ help = "Fraction of data to use for validation (0.0 to 1.0). 0.0 disables validation." ,
1082+ )
1083+ parser .add_argument (
1084+ "--validation_frequency" ,
1085+ type = int ,
1086+ default = None ,
1087+ help = "How often to evaluate validation loss (in training steps). Required when validation_split > 0." ,
1088+ )
9641089 args = parser .parse_args ()
1090+
1091+ if args .validation_split > 0.0 and (
1092+ args .validation_frequency is None or args .validation_frequency <= 0
1093+ ):
1094+ parser .error (
1095+ "--validation_frequency must be provided and positive when --validation_split > 0"
1096+ )
1097+
9651098 if args .document_column_name is not None and args .block_size is None :
9661099 parser .error ("--document-column-name requires --block-size to be specified." )
9671100
0 commit comments