Skip to content

Commit 1f02ea6

Browse files
authored
Adds validation loss + exposes it in the API (#685)
* adds validation loss * linting + testing * more linting * more linting * more testing
1 parent 320b794 commit 1f02ea6

File tree

5 files changed

+318
-25
lines changed

5 files changed

+318
-25
lines changed

src/instructlab/training/config.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,3 +340,27 @@ class TrainingArgs(BaseModel):
340340
default=None,
341341
description="Directory for TensorBoard logs. Defaults to ckpt_output_dir if not specified.",
342342
)
343+
344+
    # Fraction of the training data carved off into a held-out validation
    # set. 0.0 (the default) disables validation entirely; the range is
    # enforced as [0.0, 1.0) by validate_validation_config below.
    validation_split: float = Field(
        default=0.0,
        description="Fraction of data to use for validation (0.0 to 1.0). 0.0 disables validation.",
    )

    # Validation cadence, measured in training steps. Must be a positive
    # integer whenever validation_split > 0 (also enforced by
    # validate_validation_config); may stay None when validation is off.
    validation_frequency: Optional[int] = Field(
        default=None,
        description="How often to evaluate validation loss (in training steps). Required when validation_split > 0.",
    )
353+
354+
@model_validator(mode="after")
355+
def validate_validation_config(self):
356+
if not 0.0 <= self.validation_split < 1.0:
357+
raise ValueError(
358+
f"validation_split must be in [0.0, 1.0), got {self.validation_split}"
359+
)
360+
if self.validation_split > 0.0 and (
361+
self.validation_frequency is None or self.validation_frequency <= 0
362+
):
363+
raise ValueError(
364+
"validation_frequency must be provided and > 0 when validation_split > 0"
365+
)
366+
return self

src/instructlab/training/main_ds.py

Lines changed: 134 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,10 +84,95 @@
8484
logger = logging.getLogger(__name__)
8585

8686

87+
def compute_validation_loss(model, val_data_loader, device):
    """Compute validation loss on the validation dataset.

    Follows the same loss computation as training (manual per-token CE loss)
    but without backward passes or gradient accumulation.

    Args:
        model: The model to evaluate.
        val_data_loader: Validation data loader yielding batches of
            micro-batch dicts; ``None`` disables validation entirely.
        device: Device to run evaluation on.

    Returns:
        dict: Dictionary with "val_loss" and "val_num_tokens"; empty if no
        loader was given or no loss-counted tokens were seen.
    """
    if val_data_loader is None:
        return {}

    local_rank = int(os.environ["LOCAL_RANK"])
    base_logger = logging.getLogger("instructlab.training")

    if local_rank == 0:
        base_logger.info("Computing validation loss...")

    # Remember the caller's mode so we can restore it even if a forward
    # pass raises (previously an exception left the model in eval mode,
    # and train mode was forced regardless of the mode on entry).
    was_training = model.training
    model.eval()
    total_loss = 0.0
    total_tokens = 0

    # Hoisted out of the loops: the loss module is stateless, so there is
    # no need to reconstruct it once per micro-batch.
    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")

    try:
        with torch.no_grad():
            for batch in val_data_loader:
                batch_loss = 0.0

                for mb in batch:
                    # Prepare model inputs
                    input_ids = mb["input_ids"].to(device)
                    labels = mb["labels"].to(device)

                    model_inputs = {"input_ids": input_ids, "labels": labels}

                    if "position_ids" in mb:
                        model_inputs["position_ids"] = mb["position_ids"].to(device)
                    if "attention_mask" in mb:
                        model_inputs["attention_mask"] = mb["attention_mask"].to(device)

                    # Forward pass
                    output = model(**model_inputs, use_cache=False)

                    # Manual CE loss computation (same as model.compute_loss
                    # but without scaling): next-token shift, then the sum of
                    # per-token losses over non-ignored (-100) positions.
                    logits = output.logits
                    shift_logits = logits[..., :-1, :].contiguous()
                    shift_labels = labels[..., 1:].contiguous()
                    shift_logits = shift_logits.view(-1, shift_logits.size(-1))
                    shift_labels = shift_labels.view(-1)

                    token_losses = loss_fct(shift_logits, shift_labels)
                    valid_tokens = shift_labels != -100
                    batch_loss += token_losses[valid_tokens].sum().item()

                total_loss += batch_loss
                total_tokens += batch[0]["batch_num_loss_counted_tokens"]

        # Single reduction after all batches (SUM is associative). Guarded so
        # the function also works in single-process runs where the default
        # process group was never initialized; the distributed path is
        # unchanged.
        if dist.is_available() and dist.is_initialized():
            loss_tensor = torch.tensor(total_loss, device=device)
            tokens_tensor = torch.tensor(total_tokens, device=device, dtype=torch.long)
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.SUM)
            dist.all_reduce(tokens_tensor, op=dist.ReduceOp.SUM)
            total_loss = loss_tensor.item()
            total_tokens = int(tokens_tensor.item())
    finally:
        # Restore the mode the model arrived in, no matter what happened.
        if was_training:
            model.train()

    val_metrics = {}
    if total_tokens > 0:
        avg_val_loss = total_loss / total_tokens
        val_metrics = {
            "val_loss": avg_val_loss,
            "val_num_tokens": int(total_tokens),
        }
        if local_rank == 0:
            base_logger.info("Validation loss: %.6f", avg_val_loss)

    return val_metrics
168+
169+
87170
def train(
88171
args,
89172
model: Model,
90173
accelerator: Accelerator,
174+
val_data_loader=None,
175+
validation_frequency=None,
91176
):
92177
model.train()
93178

@@ -193,6 +278,22 @@ def train(
193278
extra={"step": global_step},
194279
)
195280

281+
# Compute validation loss if it's time to validate
282+
if (
283+
val_data_loader is not None
284+
and validation_frequency is not None
285+
and global_step % validation_frequency == 0
286+
):
287+
torch_device = torch.device("cuda", local_rank)
288+
val_metrics = compute_validation_loss(
289+
model, val_data_loader, torch_device
290+
)
291+
if val_metrics and local_rank == 0:
292+
metric_logger.info(
293+
val_metrics,
294+
extra={"step": global_step},
295+
)
296+
196297
if args.save_samples > 0 and (samples_seen % args.save_samples == 0):
197298
base_logger.debug(f"Saving checkpoint at step {global_step}")
198299
save_checkpoint(
@@ -377,7 +478,10 @@ def main(args):
377478

378479
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else 0
379480

380-
train_loader = get_data_loader(
481+
validation_split = getattr(args, "validation_split", 0.0)
482+
validation_frequency = getattr(args, "validation_frequency", None)
483+
484+
train_loader, val_loader = get_data_loader(
381485
data_path=args.data_path,
382486
batch_size=batch_size,
383487
max_tokens_per_gpu=packing_max_batch_len,
@@ -388,6 +492,7 @@ def main(args):
388492
flash_enabled=flash_enabled,
389493
pad_token_id=pad_token_id,
390494
pretraining_config=getattr(args, "pretraining_config", None),
495+
validation_split=validation_split,
391496
)
392497

393498
if args.local_rank == 0:
@@ -454,6 +559,8 @@ def main(args):
454559
args,
455560
model=m,
456561
accelerator=accelerator,
562+
val_data_loader=val_loader,
563+
validation_frequency=validation_frequency,
457564
)
458565

459566
dist.barrier()
@@ -585,6 +692,12 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
585692
if train_args.tensorboard_log_dir is not None:
586693
command.append(f"--tensorboard_log_dir={train_args.tensorboard_log_dir}")
587694

695+
# Validation parameters
696+
if train_args.validation_split > 0.0:
697+
command.append(f"--validation_split={train_args.validation_split}")
698+
if train_args.validation_frequency is not None:
699+
command.append(f"--validation_frequency={train_args.validation_frequency}")
700+
588701
if train_args.pretraining_config is not None:
589702
command.append(f"--block-size={train_args.pretraining_config.block_size}")
590703
command.append(
@@ -961,7 +1074,27 @@ def run_training(torch_args: TorchrunArgs, train_args: TrainingArgs) -> None:
9611074
default=1e-8,
9621075
help="Epsilon for numerical stability in AdamW optimizer.",
9631076
)
1077+
parser.add_argument(
1078+
"--validation_split",
1079+
type=float,
1080+
default=0.0,
1081+
help="Fraction of data to use for validation (0.0 to 1.0). 0.0 disables validation.",
1082+
)
1083+
parser.add_argument(
1084+
"--validation_frequency",
1085+
type=int,
1086+
default=None,
1087+
help="How often to evaluate validation loss (in training steps). Required when validation_split > 0.",
1088+
)
9641089
args = parser.parse_args()
1090+
1091+
if args.validation_split > 0.0 and (
1092+
args.validation_frequency is None or args.validation_frequency <= 0
1093+
):
1094+
parser.error(
1095+
"--validation_frequency must be provided and positive when --validation_split > 0"
1096+
)
1097+
9651098
if args.document_column_name is not None and args.block_size is None:
9661099
parser.error("--document-column-name requires --block-size to be specified.")
9671100

0 commit comments

Comments
 (0)