Skip to content

Commit 4793948

Browse files
author
Hossein Kavianihamedani
committed
Add configurable datasets and validation, and shorten the code
1 parent 53371c6 commit 4793948

File tree

3 files changed

+94
-86
lines changed

3 files changed

+94
-86
lines changed

apps/sft/llama3_8b.yaml

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,27 @@ optimizer:
2626
lr_scheduler:
2727
warmup_steps: 200
2828

29+
dataset:
30+
path: "yahma/alpaca-cleaned"
31+
split: "train[:95%]"
32+
33+
dataset_val:
34+
path: "yahma/alpaca-cleaned"
35+
split: "train[95%:]"
36+
2937
training:
3038
local_batch_size: 1
3139
seq_len: 2048
3240
max_norm: 1.0
3341
steps: 1000
3442
compile: false
35-
dataset: "c4"
36-
#eval_interval: 500 # Setting eval_interval to run evaluation
37-
#eval_steps: 100 # Number of validation batches during each evaluation run
43+
44+
45+
validation:
46+
enabled: true # Enable/disable validation
47+
eval_interval: 100 # Run evaluation every 100 training steps
48+
eval_steps: 50 # Number of batches per evaluation (0 = full epoch)
49+
3850

3951
parallelism:
4052
data_parallel_replicate_degree: 1

apps/sft/main.py

Lines changed: 65 additions & 82 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,25 @@ def __init__(self, config: DictConfig):
7979
self._rank = current_rank().rank
8080
self._size = math.prod(current_size().values())
8181

82-
# Evaluation settings
83-
self.eval_interval = job_config.training.get("eval_interval", float("inf"))
84-
self.eval_steps = job_config.training.get("eval_steps", 0)
82+
# Evaluation settings from validation config
83+
validation_config = job_config.get("validation", {})
84+
self.validation_enabled = validation_config.get("enabled", False)
85+
86+
if self.validation_enabled:
87+
self.eval_interval = validation_config.get("eval_interval")
88+
self.eval_steps = validation_config.get("eval_steps")
89+
90+
if self.eval_interval is None:
91+
raise ValueError(
92+
"validation.eval_interval is required when validation.enabled is true"
93+
)
94+
if self.eval_steps is None:
95+
raise ValueError(
96+
"validation.eval_steps is required when validation.enabled is true"
97+
)
98+
else:
99+
self.eval_interval = None
100+
self.eval_steps = None
85101

86102
self._init_dist()
87103
super().__init__(job_config)
@@ -113,23 +129,30 @@ def _init_dist(self):
113129

114130
@endpoint
115131
async def setup(self):
116-
# Setup training data (first 90% of train split)
132+
# Set up training data from config
133+
dataset_config = self.job_config.get("dataset")
134+
117135
self.train_dataloader = self.setup_data(
118-
dataset_path="yahma/alpaca-cleaned", dataset_split="train[:90%]"
136+
dataset_path=dataset_config.get("path"),
137+
dataset_split=dataset_config.get("split"),
119138
)
120139

121-
# Setup validation data (last 10% of train split)
140+
# Set up validation data from config
141+
dataset_val_config = self.job_config.get("dataset_val", {})
122142
self.val_dataloader = self.setup_data(
123-
dataset_path="yahma/alpaca-cleaned", dataset_split="train[90%:]"
143+
dataset_path=dataset_val_config.get("path", dataset_config.get("path")),
144+
dataset_split=dataset_val_config.get("split", dataset_config.get("split")),
124145
)
125146

126147
# Load checkpoint if resuming
127148
self.checkpointer.load(step=self.current_step)
128149

129-
def setup_data(
130-
self, dataset_path: str = "yahma/alpaca-cleaned", dataset_split: str = "train"
131-
):
150+
def setup_data(self, dataset_path: str, dataset_split: str):
132151
"""Setup data with configurable dataset path and split."""
152+
if not dataset_path or not dataset_split:
153+
raise ValueError(
154+
f"dataset.path and dataset.split are required in YAML config. Got path={dataset_path}, split={dataset_split}"
155+
)
133156
print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
134157
tokenizer = HuggingFaceModelTokenizer(
135158
tokenizer_json_path=os.path.join(
@@ -281,39 +304,26 @@ def train_step(self, batch) -> None:
281304

282305
def _extract_epoch_from_batch(self, batch: dict) -> int | None:
283306
"""Extract epoch number from batch metrics."""
284-
if "metrics" not in batch:
285-
return None
286-
287-
for metric in batch["metrics"]:
288-
if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
289-
return metric.value
307+
if "metrics" in batch:
308+
for metric in batch["metrics"]:
309+
if (
310+
hasattr(metric, "metric_name")
311+
and metric.metric_name == "num_epochs"
312+
):
313+
return metric.value
290314
return None
291315

292316
async def evaluate(self) -> dict[str, float]:
293-
"""Run evaluation on validation set for one complete epoch.
294-
295-
Uses prefetch + non-blocking all_reduce pattern to detect epoch completion
296-
across all ranks without blocking on every batch.
297-
298-
Pattern:
299-
- Iteration N: Start async all_reduce on next batch's epoch (non-blocking)
300-
- Process current batch while all_reduce completes in background
301-
- Iteration N+1: Check result from previous all_reduce (should be done)
302-
303-
This overlaps communication with computation for better performance.
304-
"""
317+
"""Run evaluation with async all_reduce for cross-rank epoch synchronization."""
305318
logger.info("=" * 50)
306-
logger.info("STARTING EVALUATION ")
319+
logger.info("STARTING EVALUATION")
307320
logger.info("=" * 50)
308321

309-
# Set model to eval mode
310322
for model_part in self.model_parts:
311323
model_part.eval()
312324

313325
val_dataloader = iter(self.val_dataloader)
314-
total_loss = 0.0
315-
num_batches = 0
316-
starting_epoch = None
326+
total_loss, num_batches, starting_epoch = 0.0, 0, None
317327

318328
# Prefetch first batch
319329
try:
@@ -322,106 +332,79 @@ async def evaluate(self) -> dict[str, float]:
322332
logger.warning("Validation dataloader is empty")
323333
return {"val_loss": 0.0, "val_batches": 0}
324334

325-
next_should_break = False
326-
pending_work = None # Handle for async all_reduce
327-
epoch_tensor = None # Tensor for all_reduce result
335+
should_break, pending_work, epoch_tensor = False, None, None
328336

329337
with torch.no_grad():
330338
while True:
331-
# Check result from PREVIOUS iteration's async all_reduce
339+
# Wait for previous async all_reduce to complete
332340
if pending_work is not None:
333-
pending_work.wait() # Should be complete (or very fast) since we did compute
334-
if epoch_tensor is not None:
335-
next_should_break = epoch_tensor.item() > 0
341+
pending_work.wait()
342+
should_break = (
343+
epoch_tensor.item() > 0 if epoch_tensor is not None else False
344+
)
336345
pending_work = None
337346

338-
# Check if we should break (based on previous iteration's check)
339-
if next_should_break:
347+
if should_break:
340348
logger.info(
341349
"Epoch completed across all ranks - stopping evaluation"
342350
)
343351
break
344352

345-
# Check optional cap on eval steps
346353
if self.eval_steps > 0 and num_batches >= self.eval_steps:
347354
logger.info(f"Reached eval_steps cap of {self.eval_steps}")
348355
break
349356

350-
# Use the batch that was prefetched in previous iteration
351357
batch = next_batch
352358

353-
# Extract epoch from current batch
359+
# Track starting epoch
354360
current_epoch = self._extract_epoch_from_batch(batch)
355361
if current_epoch is not None and starting_epoch is None:
356362
starting_epoch = current_epoch
357-
logger.info(f"Starting evaluation at epoch {starting_epoch}")
358363

359-
# Prefetch next batch and start async all_reduce
364+
# Prefetch next batch and start async epoch check
360365
try:
361366
next_batch = next(val_dataloader)
362-
363-
# Extract epoch from next batch
364367
next_epoch = self._extract_epoch_from_batch(next_batch)
365368

366-
# Start NON-BLOCKING all_reduce to check if any rank completed epoch
367369
if next_epoch is not None and starting_epoch is not None:
368-
# Check if next batch indicates epoch completion
369370
epoch_increment = next_epoch - starting_epoch
370-
371371
if torch.distributed.is_initialized():
372-
# Create tensor for all_reduce
373372
epoch_tensor = torch.tensor(
374373
[epoch_increment], dtype=torch.long, device=self.device
375374
)
376-
# Start async all_reduce (returns immediately, doesn't block)
377375
pending_work = torch.distributed.all_reduce(
378376
epoch_tensor,
379377
op=torch.distributed.ReduceOp.MAX,
380-
async_op=True, # NON-BLOCKING - returns immediately
378+
async_op=True,
381379
)
382380
else:
383-
# Single rank case - just check locally
384-
next_should_break = epoch_increment > 0
385-
381+
should_break = epoch_increment > 0
386382
except StopIteration:
387-
# No more batches - this is the last one
388-
next_should_break = True
383+
should_break = True
389384

390-
# Process current batch (while all_reduce completes in background)
391-
# Move tensors to device
385+
# Process current batch (overlaps with async all_reduce)
392386
for k, v in batch.items():
393387
if isinstance(v, torch.Tensor):
394388
batch[k] = v.to(self.device)
395389

396390
labels = batch.pop("labels")
397391
loss = self.forward_only(batch, labels)
398-
# GPU compute happens here while network does all_reduce
399-
400392
total_loss += loss.item()
401393
num_batches += 1
402394

403-
eval_steps_info = f"/{self.eval_steps}" if self.eval_steps > 0 else ""
404-
logger.info(
405-
f" Eval batch {num_batches}{eval_steps_info} | Loss: {loss.item():.4f}"
406-
)
395+
if num_batches % 10 == 0:
396+
logger.info(f" Eval batch {num_batches} | Loss: {loss.item():.4f}")
407397

408-
# Set model back to train mode
409398
for model_part in self.model_parts:
410399
model_part.train()
411400

412401
avg_loss = total_loss / max(num_batches, 1)
413-
414-
metrics = {
415-
"val_loss": avg_loss,
416-
"val_batches": num_batches,
417-
}
418-
419-
logger.info("-" * 50)
420-
logger.info(f"EVALUATION COMPLETE")
421-
logger.info(f"Validation Loss: {avg_loss:.4f}")
422-
logger.info(f"Batches Evaluated: {num_batches}")
402+
logger.info(
403+
f"EVALUATION COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}"
404+
)
423405
logger.info("=" * 50)
424-
return metrics
406+
407+
return {"val_loss": avg_loss, "val_batches": num_batches}
425408

426409
@endpoint
427410
async def train(self) -> None:
@@ -439,8 +422,8 @@ async def train(self) -> None:
439422
self.train_step(batch)
440423
self.current_step += 1
441424

442-
# Run evaluation periodically
443-
if self.current_step % self.eval_interval == 0:
425+
# Run evaluation periodically if enabled
426+
if self.validation_enabled and self.current_step % self.eval_interval == 0:
444427
eval_metrics = await self.evaluate()
445428
logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")
446429

apps/sft/qwen3_8b.yaml

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,13 +25,26 @@ optimizer:
2525
lr_scheduler:
2626
warmup_steps: 200
2727

28+
# Dataset configuration
29+
dataset:
30+
path: "yahma/alpaca-cleaned"
31+
split: "train[:95%]"
32+
33+
dataset_val:
34+
path: "yahma/alpaca-cleaned"
35+
split: "train[95%:]"
36+
2837
training:
2938
local_batch_size: 1
3039
seq_len: 2048
3140
max_norm: 1.0
3241
steps: 1000
3342
compile: false
34-
dataset: "c4"
43+
44+
validation:
45+
enabled: true # Enable/disable validation
46+
eval_interval: 100 # Run evaluation every 100 training steps
47+
eval_steps: 50 # Number of batches per evaluation (0 = full epoch)
3548

3649
parallelism:
3750
data_parallel_replicate_degree: 1

0 commit comments

Comments
 (0)