diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml
index 43a690c1e..f24936670 100644
--- a/apps/sft/llama3_8b.yaml
+++ b/apps/sft/llama3_8b.yaml
@@ -26,13 +26,27 @@ optimizer:
 lr_scheduler:
   warmup_steps: 200
 
+dataset:
+  path: "yahma/alpaca-cleaned"
+  split: "train[:95%]"
+
+dataset_val:
+  path: "yahma/alpaca-cleaned"
+  split: "train[95%:]"
+
 training:
   local_batch_size: 1
   seq_len: 2048
   max_norm: 1.0
   steps: 1000
   compile: false
-  dataset: "c4"
+
+
+validation:
+  enabled: true  # Enable/disable validation
+  eval_interval: 100  # Run evaluation every 100 training steps
+  eval_steps: 50  # Number of batches per evaluation (0 = full epoch)
+
 
 parallelism:
   data_parallel_replicate_degree: 1
diff --git a/apps/sft/main.py b/apps/sft/main.py
index 27a8036d4..c694867fb 100644
--- a/apps/sft/main.py
+++ b/apps/sft/main.py
@@ -7,7 +7,6 @@
 """To run:
 
 python -m apps.sft.main --config apps/sft/llama3_8b.yaml
-
 """
 
 import asyncio
@@ -40,8 +39,6 @@
 from torchtitan.experiments.forge.engine import ForgeEngine
 from torchtitan.experiments.forge.job_config import ForgeJobConfig
 
-# from tqdm import tqdm
-
 # stubs for now
 Checkpointer = Any
 Dataloader = Any
@@ -64,7 +61,7 @@ class ForgeSFTRecipe(ForgeActor, ForgeEngine):
     checkpointer: Checkpointer
     tokenizer: Tokenizer
     train_dataloader: Dataloader
-    # val_dataloader: Dataloader
+    val_dataloader: Dataloader
     metric_logger: MetricLogger
     profiler: Profiler
     device: torch.device
@@ -81,6 +78,27 @@ def __init__(self, config: DictConfig):
         self.gradient_accumulation_steps = 1  # Example value, adjust as needed
         self._rank = current_rank().rank
         self._size = math.prod(current_size().values())
+
+        # Evaluation settings from validation config
+        validation_config = job_config.get("validation", {})
+        self.validation_enabled = validation_config.get("enabled", False)
+
+        if self.validation_enabled:
+            self.eval_interval = validation_config.get("eval_interval")
+            self.eval_steps = validation_config.get("eval_steps")
+
+            if self.eval_interval is None:
+                raise ValueError(
+                    "validation.eval_interval is required when validation.enabled is true"
+                )
+            if self.eval_steps is None:
+                raise ValueError(
+                    "validation.eval_steps is required when validation.enabled is true"
+                )
+        else:
+            self.eval_interval = None
+            self.eval_steps = None
+
         self._init_dist()
         super().__init__(job_config)
 
@@ -111,25 +129,30 @@ def _init_dist(self):
 
     @endpoint
     async def setup(self):
-        self.train_dataloader = self.setup_data()
-        # self.train_dataloader = self.setup_data(
-        #     self.train_config.train_dataset_config,
-        #     self.train_config.train_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
-        # self.val_dataloader = self.setup_data(
-        #     self.train_config.val_dataset_config,
-        #     self.train_config.val_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
-
-        # TODO: confirm that this is working properly
-        # Should also use load, not dcp_load
+        # Setup training data from config
+        dataset_config = self.job_config.get("dataset")
+
+        self.train_dataloader = self.setup_data(
+            dataset_path=dataset_config.get("path"),
+            dataset_split=dataset_config.get("split"),
+        )
+
+        # Setup validation data from config
+        dataset_val_config = self.job_config.get("dataset_val", {})
+        self.val_dataloader = self.setup_data(
+            dataset_path=dataset_val_config.get("path", dataset_config.get("path")),
+            dataset_split=dataset_val_config.get("split", dataset_config.get("split")),
+        )
+
+        # Load checkpoint if resuming
         self.checkpointer.load(step=self.current_step)
-        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
-        # self.logger = self.setup_logger(self.train_config.logger_config)
 
-    def setup_data(self):
+    def setup_data(self, dataset_path: str, dataset_split: str):
+        """Setup data with configurable dataset path and split."""
+        if not dataset_path or not dataset_split:
+            raise ValueError(
+                f"dataset.path and dataset.split are required in YAML config. Got path={dataset_path}, split={dataset_split}"
+            )
         print(os.path.join(self.job_config.model.hf_assets_path, "tokenizer.json"))
         tokenizer = HuggingFaceModelTokenizer(
             tokenizer_json_path=os.path.join(
@@ -146,8 +169,8 @@ def setup_data(self):
         dataset = sft_iterable_dataset(
             model_transform=tokenizer,
             message_transform=AlpacaToMessages(),
-            path="yahma/alpaca-cleaned",
-            split="train",
+            path=dataset_path,
+            split=dataset_split,
         )
         packer = TextPacker(padding_idx=0)
         dataset = PackedDataset(
@@ -163,10 +186,6 @@ def setup_data(self):
             ),
         )
 
-        # Ultimately we probably want something like this
-        # packer = build_packing_strategy(packing_config)
-        # dataset = build_dataset(dataset_config)
-        # dataloader = build_dataloader(dataloader_config, dataset, packer)
         return dataloader
 
     def forward_backward(
@@ -206,7 +225,6 @@ def forward_backward(
             )
 
             # accumulate losses across pipeline microbatches
-            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
             loss = (
                 torch.mean(torch.stack(losses)).to(self.device)
                 if self.pp_has_last_stage
@@ -225,27 +243,173 @@
         return loss
 
+    def forward_only(
+        self, input_dict: dict[str, torch.Tensor], labels: torch.Tensor
+    ) -> torch.Tensor:
+        """Forward pass only for evaluation (no backward)."""
+        model_parts = self.model_parts
+        parallel_dims = self.parallel_dims
+
+        inputs = input_dict["tokens"]
+        optional_context_parallel_ctx = (
+            dist_utils.create_context_parallel_ctx(
+                cp_mesh=parallel_dims.world_mesh["cp"],
+                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
+                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
+                cp_no_restore_buffers={inputs, labels},
+                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
+            )
+            if parallel_dims.cp_enabled
+            else None
+        )
+
+        if parallel_dims.pp_enabled:
+            # Pipeline Parallel forward only
+            with self.train_context(optional_context_parallel_ctx):
+                targets, losses = (
+                    (labels, []) if self.pp_has_last_stage else (None, None)
+                )
+                if self.pp_has_first_stage:
+                    self.pp_schedule.step(
+                        inputs, target=targets, losses=losses, input_batch=inputs
+                    )
+                else:
+                    self.pp_schedule.step(
+                        target=targets, losses=losses, input_batch=inputs
+                    )
+
+            loss = (
+                torch.mean(torch.stack(losses)).to(self.device)
+                if self.pp_has_last_stage
+                else torch.tensor([-1.0], device=self.device)
+            )
+        else:
+            # Non-PP forward only
+            with self.train_context(optional_context_parallel_ctx):
+                assert len(model_parts) == 1
+                with self.maybe_enable_amp:
+                    pred = model_parts[0](inputs)
+                    loss = self.loss_fn(pred, labels)
+                    del pred
+
+        return loss
+
     def train_step(self, batch) -> None:
-        # TODO
-        # with GradientAccumulation(
-        #     self.gradient_accumulation_steps,
-        #     self.model,
-        #     self.data_parallel_size,
-        # ) as grad_acc:
         labels = batch.pop("labels")
         loss = self.forward_backward(batch, labels)
         logger.info(f"{self.current_step} / {self.num_training_steps}|Loss: {loss}")
-        # self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
-        # self.pbar.update(1)
 
         self.optimizers.step()
         self.lr_schedulers.step()
 
+    def _extract_epoch_from_batch(self, batch: dict) -> int | None:
metrics.""" + if "metrics" in batch: + for metric in batch["metrics"]: + if ( + hasattr(metric, "metric_name") + and metric.metric_name == "num_epochs" + ): + return metric.value + return None + + async def evaluate(self) -> dict[str, float]: + """Run evaluation with async all_reduce for cross-rank epoch synchronization.""" + logger.info("=" * 50) + logger.info("STARTING EVALUATION") + logger.info("=" * 50) + + for model_part in self.model_parts: + model_part.eval() + + val_dataloader = iter(self.val_dataloader) + total_loss, num_batches, starting_epoch = 0.0, 0, None + + # Prefetch first batch + try: + next_batch = next(val_dataloader) + except StopIteration: + logger.warning("Validation dataloader is empty") + return {"val_loss": 0.0, "val_batches": 0} + + should_break, pending_work, epoch_tensor = False, None, None + + with torch.no_grad(): + while True: + # Wait for previous async all_reduce to complete + if pending_work is not None: + pending_work.wait() + should_break = ( + epoch_tensor.item() > 0 if epoch_tensor is not None else False + ) + pending_work = None + + if should_break: + logger.info( + "Epoch completed across all ranks - stopping evaluation" + ) + break + + if self.eval_steps > 0 and num_batches >= self.eval_steps: + logger.info(f"Reached eval_steps cap of {self.eval_steps}") + break + + batch = next_batch + + # Track starting epoch + current_epoch = self._extract_epoch_from_batch(batch) + if current_epoch is not None and starting_epoch is None: + starting_epoch = current_epoch + + # Prefetch next batch and start async epoch check + try: + next_batch = next(val_dataloader) + next_epoch = self._extract_epoch_from_batch(next_batch) + + if next_epoch is not None and starting_epoch is not None: + epoch_increment = next_epoch - starting_epoch + if torch.distributed.is_initialized(): + epoch_tensor = torch.tensor( + [epoch_increment], dtype=torch.long, device=self.device + ) + pending_work = torch.distributed.all_reduce( + epoch_tensor, + op=torch.distributed.ReduceOp.MAX, + async_op=True, + ) + else: + should_break = epoch_increment > 0 + except StopIteration: + should_break = True + + # Process current batch (overlaps with async all_reduce) + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(self.device) + + labels = batch.pop("labels") + loss = self.forward_only(batch, labels) + total_loss += loss.item() + num_batches += 1 + + if num_batches % 10 == 0: + logger.info(f" Eval batch {num_batches} | Loss: {loss.item():.4f}") + + for model_part in self.model_parts: + model_part.train() + + avg_loss = total_loss / max(num_batches, 1) + logger.info( + f"EVALUATION COMPLETE | Val Loss: {avg_loss:.4f} | Batches: {num_batches}" + ) + logger.info("=" * 50) + + return {"val_loss": avg_loss, "val_batches": num_batches} + @endpoint async def train(self) -> None: dataloader = iter(self.train_dataloader) self.optimizers.zero_grad() - # TODO: tqdm is broken in Monarch actors # self.pbar = tqdm(initial=self.current_step, total=self.num_training_steps) @@ -254,18 +418,21 @@ async def train(self) -> None: # Move tensors to the appropriate device for k, v in batch.items(): if isinstance(v, torch.Tensor): - batch[k] = v.to("cuda") # TODO: hardcoded for now + batch[k] = v.to(self.device) # TODO: hardcoded for now self.train_step(batch) - # self.profiler.step() self.current_step += 1 + # Run evaluation periodically if enabled + if self.validation_enabled and self.current_step % self.eval_interval == 0: + eval_metrics = await self.evaluate() + logger.info(f"Step 
+                logger.info(f"Step {self.current_step} | Eval metrics: {eval_metrics}")
+
+            # Save checkpoints
             self.checkpointer.save(
                 curr_step=self.current_step,
                 last_step=self.current_step == self.num_training_steps,
             )
 
-        # self.pbar.close()
-
     @endpoint
     async def cleanup(self) -> None:
         if self.checkpointer:
diff --git a/apps/sft/qwen3_8b.yaml b/apps/sft/qwen3_8b.yaml
index 2ab88bbd3..2d4128065 100644
--- a/apps/sft/qwen3_8b.yaml
+++ b/apps/sft/qwen3_8b.yaml
@@ -25,13 +25,26 @@ optimizer:
 lr_scheduler:
   warmup_steps: 200
 
+# Dataset configuration
+dataset:
+  path: "yahma/alpaca-cleaned"
+  split: "train[:95%]"
+
+dataset_val:
+  path: "yahma/alpaca-cleaned"
+  split: "train[95%:]"
+
 training:
   local_batch_size: 1
   seq_len: 2048
   max_norm: 1.0
   steps: 1000
   compile: false
-  dataset: "c4"
+
+validation:
+  enabled: true  # Enable/disable validation
+  eval_interval: 100  # Run evaluation every 100 training steps
+  eval_steps: 50  # Number of batches per evaluation (0 = full epoch)
 
 parallelism:
   data_parallel_replicate_degree: 1
diff --git a/apps/sft/test_evaluate.py b/apps/sft/test_evaluate.py
new file mode 100644
index 000000000..57959b09d
--- /dev/null
+++ b/apps/sft/test_evaluate.py
@@ -0,0 +1,437 @@
+"""
+Tests for the non-blocking all_reduce evaluation logic in main.py
+
+This tests the epoch-detection and async all_reduce pattern used to
+synchronize evaluation completion across multiple ranks without blocking.
+"""
+
+from dataclasses import dataclass
+from unittest.mock import MagicMock, Mock, patch
+
+import pytest
+import torch
+
+
+@dataclass
+class MockMetric:
+    """Mock metric object matching the structure in batch["metrics"]"""
+
+    metric_name: str
+    value: int
+
+
+class MockTrainer:
+    """Mock trainer with minimal setup for testing evaluate logic"""
+
+    def __init__(self, eval_steps=0):
+        self.eval_steps = eval_steps
+        self.device = torch.device("cpu")
+        self.model_parts = [Mock()]
+
+    def _extract_epoch_from_batch(self, batch: dict) -> int | None:
+        """Extract epoch number from batch metrics."""
+        if "metrics" not in batch:
+            return None
+
+        for metric in batch["metrics"]:
+            if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
+                return metric.value
+        return None
+
+    def forward_only(self, batch, labels):
+        """Mock forward pass - returns dummy loss"""
+        return torch.tensor(1.5)
+
+
+def create_batch_with_epoch(epoch: int, loss_value: float = 1.5):
+    """Helper to create a mock batch with epoch metadata"""
+    return {
+        "input_ids": torch.randn(2, 10),
+        "attention_mask": torch.ones(2, 10),
+        "labels": torch.randint(0, 100, (2, 10)),
+        "metrics": [MockMetric(metric_name="num_epochs", value=epoch)],
+    }
+
+
+def create_batch_without_epoch(loss_value: float = 1.5):
+    """Helper to create a mock batch without epoch metadata"""
+    return {
+        "input_ids": torch.randn(2, 10),
+        "attention_mask": torch.ones(2, 10),
+        "labels": torch.randint(0, 100, (2, 10)),
+    }
+
+
+class TestExtractEpochFromBatch:
+    """Test the _extract_epoch_from_batch helper method"""
+
+    def test_extract_epoch_success(self):
+        """Test extracting epoch from batch with proper metadata"""
+        trainer = MockTrainer()
+        batch = create_batch_with_epoch(epoch=5)
+
+        epoch = trainer._extract_epoch_from_batch(batch)
+        assert epoch == 5
+
+    def test_extract_epoch_no_metrics(self):
+        """Test batch without metrics returns None"""
+        trainer = MockTrainer()
+        batch = create_batch_without_epoch()
+
+        epoch = trainer._extract_epoch_from_batch(batch)
+        assert epoch is None
+
+    def test_extract_epoch_wrong_metric_name(self):
+        """Test batch with metrics but wrong metric_name returns None"""
+        trainer = MockTrainer()
+        batch = {
+            "input_ids": torch.randn(2, 10),
+            "metrics": [MockMetric(metric_name="other_metric", value=10)],
+        }
+
+        epoch = trainer._extract_epoch_from_batch(batch)
+        assert epoch is None
+
+    def test_extract_epoch_multiple_metrics(self):
+        """Test extracting epoch from batch with multiple metrics"""
+        trainer = MockTrainer()
+        batch = {
+            "input_ids": torch.randn(2, 10),
+            "metrics": [
+                MockMetric(metric_name="loss", value=1.5),
+                MockMetric(metric_name="num_epochs", value=3),
+                MockMetric(metric_name="step", value=100),
+            ],
+        }
+
+        epoch = trainer._extract_epoch_from_batch(batch)
+        assert epoch == 3
+
+
+class TestEvaluationLogic:
+    """Test the evaluation loop logic (single-rank scenario)"""
+
+    @pytest.mark.asyncio
+    async def test_single_epoch_completion(self):
+        """Test that evaluation stops after one complete epoch"""
+        trainer = MockTrainer(eval_steps=0)  # No cap
+
+        # Create batches: 3 from epoch 0, then epoch increments to 1
+        batches = [
+            create_batch_with_epoch(0),
+            create_batch_with_epoch(0),
+            create_batch_with_epoch(0),
+            create_batch_with_epoch(1),  # Epoch increment - should trigger stop
+        ]
+
+        dataloader = iter(batches)
+
+        # Simulate the evaluation pattern
+        num_processed = 0
+        starting_epoch = None
+        next_should_break = False
+
+        # Get first batch
+        next_batch = next(dataloader)
+
+        while True:
+            if next_should_break:
+                break
+
+            batch = next_batch
+
+            # Extract epoch from current batch
+            current_epoch = trainer._extract_epoch_from_batch(batch)
+            if current_epoch is not None and starting_epoch is None:
+                starting_epoch = current_epoch
+
+            # Try to prefetch next batch
+            try:
+                next_batch = next(dataloader)
+                next_epoch = trainer._extract_epoch_from_batch(next_batch)
+
+                # Check for epoch increment
+                if next_epoch is not None and starting_epoch is not None:
+                    epoch_increment = next_epoch - starting_epoch
+                    next_should_break = epoch_increment > 0
+
+            except StopIteration:
+                next_should_break = True
+
+            # Process current batch
+            num_processed += 1
+
+        # Should have processed 3 batches (stopped when detected epoch 1)
+        assert num_processed == 3
+        assert starting_epoch == 0
+
+    @pytest.mark.asyncio
+    async def test_eval_steps_cap(self):
+        """Test that evaluation respects eval_steps cap"""
+        trainer = MockTrainer(eval_steps=2)  # Cap at 2 batches
+
+        # Create 5 batches all in same epoch
+        batches = [create_batch_with_epoch(0) for _ in range(5)]
+        dataloader = iter(batches)
+
+        # Simulate the evaluation pattern
+        num_processed = 0
+        next_should_break = False
+
+        # Get first batch
+        next_batch = next(dataloader)
+
+        while True:
+            if next_should_break:
+                break
+
+            # Check eval_steps cap
+            if trainer.eval_steps > 0 and num_processed >= trainer.eval_steps:
+                break
+
+            batch = next_batch
+
+            # Try to prefetch next batch
+            try:
+                next_batch = next(dataloader)
+            except StopIteration:
+                next_should_break = True
+
+            # Process current batch
+            num_processed += 1
+
+        # Should have processed exactly 2 batches (eval_steps cap)
+        assert num_processed == 2
+
+    @pytest.mark.asyncio
+    async def test_empty_dataloader(self):
+        """Test handling of empty dataloader"""
+        trainer = MockTrainer(eval_steps=0)
+
+        batches = []
+        dataloader = iter(batches)
+
+        # Should raise StopIteration immediately
+        with pytest.raises(StopIteration):
+            next_batch = next(dataloader)
+
+    @pytest.mark.asyncio
+    async def test_single_batch(self):
+        """Test evaluation with only one batch"""
+        trainer = MockTrainer(eval_steps=0)
+
+        batches = [create_batch_with_epoch(0)]
+        dataloader = iter(batches)
+
+        num_processed = 0
+        next_should_break = False
+
+        # Get first batch
+        next_batch = next(dataloader)
+
+        while True:
+            if next_should_break:
+                break
+
+            batch = next_batch
+
+            # Try to prefetch next batch
+            try:
+                next_batch = next(dataloader)
+            except StopIteration:
+                next_should_break = True
+
+            # Process current batch
+            num_processed += 1
+
+        # Should have processed 1 batch
+        assert num_processed == 1
+
+    @pytest.mark.asyncio
+    async def test_no_epoch_metadata(self):
+        """Test evaluation when batches don't have epoch metadata"""
+        trainer = MockTrainer(eval_steps=3)  # Use eval_steps as fallback
+
+        # Create batches without epoch metadata
+        batches = [create_batch_without_epoch() for _ in range(5)]
+        dataloader = iter(batches)
+
+        num_processed = 0
+        next_should_break = False
+        next_batch = next(dataloader)
+
+        while True:
+            if next_should_break:
+                break
+
+            # Check eval_steps cap (should be the stopping condition)
+            if trainer.eval_steps > 0 and num_processed >= trainer.eval_steps:
+                break
+
+            batch = next_batch
+
+            try:
+                next_batch = next(dataloader)
+            except StopIteration:
+                next_should_break = True
+
+            num_processed += 1
+
+        # Should stop at eval_steps
+        assert num_processed == 3
+
+
+class TestAsyncAllReduce:
+    """Test the async all_reduce pattern with mocked distributed operations"""
+
+    @pytest.mark.asyncio
+    async def test_async_all_reduce_pattern(self):
+        """Test the async all_reduce pattern with mock distributed operations"""
+
+        # Mock distributed environment
+        with patch("torch.distributed.is_initialized", return_value=True):
+            with patch("torch.distributed.all_reduce") as mock_all_reduce:
+
+                # Create mock Work handle for async operation
+                mock_work = Mock()
+                mock_work.wait = Mock()
+                mock_all_reduce.return_value = mock_work
+
+                trainer = MockTrainer(eval_steps=0)
+
+                # Simulate the async pattern
+                epoch_tensor = torch.tensor([0], dtype=torch.long)
+
+                # Start async all_reduce (should return immediately)
+                work_handle = torch.distributed.all_reduce(
+                    epoch_tensor, op=torch.distributed.ReduceOp.MAX, async_op=True
+                )
+
+                # Verify it returned immediately with a work handle
+                assert work_handle is not None
+                assert mock_all_reduce.called
+
+                # Simulate doing computation here...
+
+                # Wait for completion
+                work_handle.wait()
+                assert mock_work.wait.called
+
+    @pytest.mark.asyncio
+    async def test_multi_rank_epoch_detection(self):
+        """Test that epoch completion is detected when ANY rank finishes"""
+
+        with patch("torch.distributed.is_initialized", return_value=True):
+            with patch("torch.distributed.all_reduce") as mock_all_reduce:
+
+                def all_reduce_side_effect(tensor, op, async_op=False):
+                    """Simulate all_reduce MAX operation across ranks
+                    Rank 0: epoch_increment = 0 (still in epoch 0)
+                    Rank 1: epoch_increment = 1 (moved to epoch 1)
+                    MAX = 1, so all ranks should stop
+                    """
+                    # Simulate MAX operation - set tensor to max value
+                    tensor[0] = 1  # At least one rank has epoch_increment=1
+
+                    if async_op:
+                        mock_work = Mock()
+                        mock_work.wait = Mock()
+                        return mock_work
+                    return None
+
+                mock_all_reduce.side_effect = all_reduce_side_effect
+
+                trainer = MockTrainer(eval_steps=0)
+
+                # Simulate rank 1's perspective: it moved to epoch 1
+                starting_epoch = 0
+                next_epoch = 1
+                epoch_increment = next_epoch - starting_epoch  # = 1
+
+                epoch_tensor = torch.tensor([epoch_increment], dtype=torch.long)
+
+                # Start async all_reduce
+                work = torch.distributed.all_reduce(
+                    epoch_tensor, op=torch.distributed.ReduceOp.MAX, async_op=True
+                )
+
+                # Wait for result
+                work.wait()
+
+                # Check if should break (any rank has increment > 0)
+                should_break = epoch_tensor.item() > 0
+
+                assert should_break is True
+                assert epoch_tensor.item() == 1
+
+
+class TestEvaluationIntegration:
+    """Integration-style tests for the full evaluation flow"""
+
+    @pytest.mark.asyncio
+    async def test_prefetch_pattern_ordering(self):
+        """Test that the prefetch pattern processes batches in correct order"""
+        trainer = MockTrainer(eval_steps=0)
+
+        # Create identifiable batches
+        batches = [
+            {
+                "id": 0,
+                "metrics": [MockMetric("num_epochs", 0)],
+                "labels": torch.zeros(1),
+            },
+            {
+                "id": 1,
+                "metrics": [MockMetric("num_epochs", 0)],
+                "labels": torch.zeros(1),
+            },
+            {
+                "id": 2,
+                "metrics": [MockMetric("num_epochs", 0)],
+                "labels": torch.zeros(1),
+            },
+            {
+                "id": 3,
+                "metrics": [MockMetric("num_epochs", 1)],
+                "labels": torch.zeros(1),
+            },
+        ]
+
+        dataloader = iter(batches)
+        processed_ids = []
+
+        # Prefetch first batch
+        next_batch = next(dataloader)
+        next_should_break = False
+        starting_epoch = None
+
+        while True:
+            if next_should_break:
+                break
+
+            # Process current batch
+            batch = next_batch
+            processed_ids.append(batch["id"])
+
+            # Extract epoch
+            current_epoch = trainer._extract_epoch_from_batch(batch)
+            if current_epoch is not None and starting_epoch is None:
+                starting_epoch = current_epoch
+
+            # Prefetch next
+            try:
+                next_batch = next(dataloader)
+                next_epoch = trainer._extract_epoch_from_batch(next_batch)
+
+                if next_epoch is not None and starting_epoch is not None:
+                    epoch_increment = next_epoch - starting_epoch
+                    next_should_break = epoch_increment > 0
+            except StopIteration:
+                next_should_break = True
+
+        # Should have processed batches 0, 1, 2 (stopped when detected batch 3 has epoch 1)
+        assert processed_ids == [0, 1, 2]
+
+
+if __name__ == "__main__":
+    pytest.main([__file__, "-v"])