Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions llm/src/llm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,14 @@ class TrainingConfig:

check_for_gsa_leak: bool = False

# ── Overfitting detection ──
# Number of consecutive eval cycles without a val_loss improvement before
# an overfitting alert is logged (no exception is raised; training continues).
# 0 disables the alert only; strikes are still counted and logged.
overfit_patience: int = 5

# Margin by which val_loss must drop below the best value seen so far to count
# as an improvement (strict comparison: a drop equal to the margin does not count).
overfit_threshold: float = 0.0


@dataclass
class S3CheckpointConfig:
Expand Down
61 changes: 61 additions & 0 deletions llm/src/llm/pretrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,12 +62,19 @@ def __init__(self, local_rank: int, c: Config):
self._step_profiler.activate()
self._step_profiler.register_model(self._engine.module)

# ── Overfitting detection state ──
self._train_loss_accum: float = 0.0
self._train_loss_count: int = 0
self._best_val_loss: float = float("inf")
self._overfit_strikes: int = 0

def run(self):
start_epoch, start_step, global_step = self._resume()
max_epochs = self._config.training.max_epochs
max_steps_per_epoch = self._config.training.max_steps_per_epoch
device = self._engine.device
ckpt_interval = self._config.checkpoint.save_interval
eval_interval = self._config.training.eval_interval

for epoch in range(start_epoch, max_epochs):
if self._train_sampler:
Expand Down Expand Up @@ -126,6 +133,47 @@ def run(self):
steps += 1
total_loss += loss.item()

# Accumulate per-step training loss; it is averaged over the steps since the
# previous eval and then reset when the next eval runs.
self._train_loss_accum += loss.item()
self._train_loss_count += 1

if eval_interval and global_step % eval_interval == 0:
val_loss, val_perplexity = self._validate()

smoothed_train_loss = self._train_loss_accum / max(1, self._train_loss_count)
self._train_loss_accum = 0.0
self._train_loss_count = 0

train_eval_gap = val_loss - smoothed_train_loss

# Overfitting detection watchdog.
overfit_patience = self._config.training.overfit_patience
overfit_threshold = self._config.training.overfit_threshold
if val_loss < self._best_val_loss - overfit_threshold:
self._best_val_loss = val_loss
self._overfit_strikes = 0
else:
self._overfit_strikes += 1

self._logger.log_step(global_step, {
"train_loss": smoothed_train_loss,
"val_loss": val_loss,
"val_perplexity": val_perplexity,
"train_eval_gap": train_eval_gap,
"overfit_strikes": self._overfit_strikes,
"best_val_loss": self._best_val_loss,
})

if overfit_patience > 0 and self._overfit_strikes >= overfit_patience:
self._logger.log_step(global_step, {
"overfitting_detected": True,
"val_loss": val_loss,
"best_val_loss": self._best_val_loss,
"overfit_strikes": self._overfit_strikes,
})

self._engine.train()

metrics.add("loss", loss.item(), pbar=True)
metrics.add("global_step", global_step, pbar=True)
metrics.add("toks/sec", toks_per_sec, pbar=True)
Expand Down Expand Up @@ -395,6 +443,19 @@ def _validate(self) -> tuple[float, float]:
avg_loss = total_loss / steps
avg_perplexity = total_perplexity / steps

# Distributed-safe: average validation loss across all ranks.
val_loss_tensor = torch.tensor(avg_loss, device=self._engine.device)
if dist.is_available() and dist.is_initialized():
dist.all_reduce(val_loss_tensor, op=dist.ReduceOp.SUM)
val_loss_tensor /= dist.get_world_size()
avg_loss = val_loss_tensor.item()

val_ppl_tensor = torch.tensor(avg_perplexity, device=self._engine.device)
if dist.is_available() and dist.is_initialized():
dist.all_reduce(val_ppl_tensor, op=dist.ReduceOp.SUM)
val_ppl_tensor /= dist.get_world_size()
avg_perplexity = val_ppl_tensor.item()

return avg_loss, avg_perplexity

def _save_checkpoint(self, epoch: int, step: int, global_step: int, **kwargs):
Expand Down
Loading