diff --git a/apps/sft/llama3_8b.yaml b/apps/sft/llama3_8b.yaml
deleted file mode 100644
index 3d67cdb23..000000000
--- a/apps/sft/llama3_8b.yaml
+++ /dev/null
@@ -1,68 +0,0 @@
-
-# profiling:
-#   enable_profiling: false
-
-metrics:
-  logger: tensorboard
-  freq:
-    loss: 10
-
-# TODO: required by torchtitan
-# https://github.com/pytorch/torchtitan/blob/2f1c814da071cc8ad165d00be6f9c1a66f8e1cce/torchtitan/distributed/utils.py#L265
-comm:
-  trace_buf_size: 0
-
-model:
-  name: llama3
-  flavor: 8B
-  hf_assets_path: /tmp/Meta-Llama-3.1-8B-Instruct
-
-optimizer:
-  name: AdamW
-  lr: 1e-5
-  eps: 1e-8
-
-lr_scheduler:
-  warmup_steps: 200
-
-training:
-  local_batch_size: 1
-  seq_len: 2048
-  max_norm: 1.0
-  steps: 1000
-  compile: false
-
-validation:
-  local_batch_size: 1
-  freq: -1 # Change to a positive number to enable validation
-  steps: 200 # Max steps to run validation. Validation disabled if negative.
-
-dataset:
-  path: yahma/alpaca-cleaned
-  split: train[:95%]
-
-dataset_val:
-  path: yahma/alpaca-cleaned
-  split: train[95%:]
-
-parallelism:
-  data_parallel_replicate_degree: 1
-  data_parallel_shard_degree: -1
-  tensor_parallel_degree: 1
-  pipeline_parallel_degree: 1
-  context_parallel_degree: 1
-  expert_parallel_degree: 1
-  disable_loss_parallel: false
-
-checkpoint:
-  enable: true
-  folder: /tmp/Meta-Llama-3.1-8B-Instruct/saved_checkpoints
-  initial_load_path: /tmp/Meta-Llama-3.1-8B-Instruct/
-  initial_load_in_hf: true
-  last_save_in_hf: true
-  interval: 500
-  async_mode: "disabled"
-
-activation_checkpoint:
-  mode: selective
-  selective_ac_option: op
diff --git a/apps/sft/main.py b/apps/sft/main.py
deleted file mode 100644
index c806d037a..000000000
--- a/apps/sft/main.py
+++ /dev/null
@@ -1,313 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the BSD-style license found in the
-# LICENSE file in the root directory of this source tree.
-
-import os
-import sys
-from dataclasses import asdict
-from functools import partial
-from typing import Any
-
-import torch
-
-import torchtitan.experiments.forge.train_spec as forge_train_spec
-from forge.cli.config import parse
-from forge.data.collate import collate_packed
-from forge.data.datasets.packed import PackedDataset, TextPacker
-from forge.data.datasets.sft_dataset import AlpacaToMessages, sft_iterable_dataset
-from forge.data.tokenizer import HuggingFaceModelTokenizer
-from forge.data.utils import batch_to_device, CROSS_ENTROPY_IGNORE_IDX
-from forge.util import get_metric_logger
-
-from omegaconf import DictConfig, OmegaConf
-from torch import nn
-
-from torchdata.stateful_dataloader import StatefulDataLoader
-from torchtitan.components.loss import LossFunction
-from torchtitan.components.lr_scheduler import LRSchedulersContainer
-from torchtitan.components.optimizer import OptimizersContainer
-from torchtitan.distributed import ParallelDims, utils as dist_utils
-from torchtitan.experiments.forge.engine import ForgeEngine
-from torchtitan.experiments.forge.job_config import ForgeJobConfig
-from torchtitan.models.attention import init_attention_mask
-from tqdm import tqdm
-
-
-# stubs for now
-Checkpointer = Any
-Dataloader = Any
-MetricLogger = Any
-Profiler = Any
-Tokenizer = Any
-
-
-class ForgeSFTRecipe(ForgeEngine):
-    job_config: ForgeJobConfig
-    train_spec: forge_train_spec.ForgeTrainSpec
-    parallel_dims: ParallelDims
-    model: list[nn.Module]
-    loss_fn: LossFunction
-    optimizer: OptimizersContainer
-    lr_scheduler: LRSchedulersContainer
-    checkpointer: Checkpointer
-    tokenizer: Tokenizer
-    train_dataloader: Dataloader
-    # val_dataloader: Dataloader
-    metric_logger: MetricLogger
-    profiler: Profiler
-    device: torch.device
-    step: int
-
-    def __init__(self, job_config: ForgeJobConfig):
-        self.current_step = 0
-        self.num_training_steps = job_config.training.steps
-        self.gradient_accumulation_steps = 1  # Example value, adjust as needed
-        super().__init__(job_config)
-        self.metric_logger = get_metric_logger(**job_config.metrics)
-
-    def setup(self):
-        self.train_dataloader = self.setup_data(
-            self.job_config.dataset,
-            batch_size=self.job_config.training.local_batch_size,
-        )
-
-        self.val_dataloader = self.setup_data(
-            self.job_config.dataset_val,
-            batch_size=self.job_config.validation.local_batch_size,
-        )
-
-        # self.train_dataloader = self.setup_data(
-        #     self.train_config.train_dataset_config,
-        #     self.train_config.train_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
-        # self.val_dataloader = self.setup_data(
-        #     self.train_config.val_dataset_config,
-        #     self.train_config.val_dataloader_config,
-        #     self.train_config.packing_config,
-        # )
-
-        self.checkpointer.load(step=self.current_step)
-        # self.profiler = self.setup_profiler(self.train_config.profiler_config)
-        # self.logger = self.setup_logger(self.train_config.logger_config)
-
-    def setup_data(self, dataset_config, batch_size):
-        self.tokenizer = HuggingFaceModelTokenizer(
-            tokenizer_json_path=os.path.join(
-                self.job_config.model.hf_assets_path, "tokenizer.json"
-            ),
-            tokenizer_config_json_path=os.path.join(
-                self.job_config.model.hf_assets_path, "tokenizer_config.json"
-            ),
-            generation_config_path=os.path.join(
-                self.job_config.model.hf_assets_path, "generation_config.json"
-            ),
-        )
-
-        dataset = sft_iterable_dataset(
-            model_transform=self.tokenizer,
-            message_transform=AlpacaToMessages(),
-            path=dataset_config.path,
-            split=dataset_config.split,
-        )
-        packer = TextPacker(padding_idx=0)
-        dataset = PackedDataset(
-            dataset=dataset,
-            packer=packer,
-            target_tokens_per_pack=self.job_config.training.seq_len,  # TODO: get this from model
-        )
-        dataloader = StatefulDataLoader(
-            dataset=dataset,
-            batch_size=batch_size,
-            collate_fn=partial(
-                collate_packed, mask_fn=packer.create_block_mask, device=self.device
-            ),
-        )
-
-        # Ultimately we probably want something like this
-        # packer = build_packing_strategy(packing_config)
-        # dataset = build_dataset(dataset_config)
-        # dataloader = build_dataloader(dataloader_config, dataset, packer)
-        return dataloader
-
-    def forward_backward(
-        self,
-        input_dict: dict[str, torch.Tensor],
-        labels: torch.Tensor,
-        do_backward: bool = True,
-    ) -> torch.Tensor:
-        model_parts = self.model_parts
-        parallel_dims = self.parallel_dims
-
-        # apply context parallelism if cp is enabled
-        # ensure CP handles the separate freqs_cis buffer for each pp stage
-        inputs = input_dict["tokens"]
-
-        if getattr(self.model_args, "use_flex_attn", False):
-            cp_mesh = (
-                parallel_dims.world_mesh["cp"] if parallel_dims.cp_enabled else None
-            )
-            init_attention_mask(inputs, self.tokenizer.base_tokenizer.eos_id, cp_mesh)
-
-        optional_context_parallel_ctx = (
-            dist_utils.create_context_parallel_ctx(
-                cp_mesh=parallel_dims.world_mesh["cp"],
-                cp_buffers=[inputs, labels] + [m.freqs_cis for m in model_parts],
-                cp_seq_dims=[1, 1] + [0 for _ in model_parts],
-                cp_no_restore_buffers={inputs, labels},
-                cp_rotate_method=self.job_config.parallelism.context_parallel_rotate_method,
-            )
-            if parallel_dims.cp_enabled
-            else None
-        )
-
-        if parallel_dims.pp_enabled:
-            # Pipeline Parallel forward / backward inside step() call
-            with self.train_context(optional_context_parallel_ctx):
-                targets, losses = (
-                    (labels, []) if self.pp_has_last_stage else (None, None)
-                )
-                if do_backward:
-                    pp_schedule_fn = self.pp_schedule.step
-                else:
-                    pp_schedule_fn = self.pp_schedule.eval
-                if self.pp_has_first_stage:
-                    pp_schedule_fn(
-                        inputs, target=targets, losses=losses, input_batch=inputs
-                    )
-                else:
-                    pp_schedule_fn(target=targets, losses=losses, input_batch=inputs)
-
-            # accumulate losses across pipeline microbatches
-            # TODO: PP+FSDP unexpectedly puts the loss back to the CPU
-            loss = (
-                torch.mean(torch.stack(losses)).to(self.device)
-                if self.pp_has_last_stage
-                else torch.tensor([-1.0], device=self.device)
-            )
-        else:
-            # Non-PP forward / backward
-            with self.train_context(optional_context_parallel_ctx):
-                assert len(model_parts) == 1
-                with self.maybe_enable_amp:
-                    pred = model_parts[0](inputs)
-                    loss = self.loss_fn(pred, labels)
-                # need to free pred before bwd to avoid peak memory usage
-                del pred
-                if do_backward:
-                    loss.backward()
-
-        return loss
-
-    def train_step(self, batch) -> None:
-        # TODO
-        # with GradientAccumulation(
-        #     self.gradient_accumulation_steps,
-        #     self.model,
-        #     self.data_parallel_size,
-        # ) as grad_acc:
-        labels = batch.pop("labels")
-        loss = self.forward_backward(batch, labels)
-        self.pbar.update(1)
-        self.pbar.set_description(f"{self.current_step}|Loss: {loss}")
-        self.metric_logger.log("loss", loss.item(), self.current_step)
-
-        self.optimizers.step()
-        self.optimizers.zero_grad()
-        self.lr_schedulers.step()
-
-    def train(self) -> None:
-        dataloader = iter(self.train_dataloader)
-        self.optimizers.zero_grad()
-
-        self.pbar = tqdm(
-            initial=0,
-            total=self.num_training_steps,
-            desc=f"{self.current_step}",
-        )
-
-        while self.current_step < self.num_training_steps:
-            batch = next(dataloader)
-            # Move tensors to the appropriate device
-            for k, v in batch.items():
-                if isinstance(v, torch.Tensor):
-                    batch[k] = v.to("cuda")  # TODO: hardcoded for now
-            self.train_step(batch)
-            # self.profiler.step()
-            self.current_step += 1
-
-            self.checkpointer.save(
-                curr_step=self.current_step,
-                last_step=self.current_step == self.num_training_steps,
-            )
-
-            if (
-                self.job_config.validation.freq > 0
-                and self.job_config.validation.steps > 0
-                and self.current_step % self.job_config.validation.freq == 0
-            ):
-                self.validate(self.job_config.validation.steps)
-
-    def validate(self, max_steps: int) -> None:
-        for m in self.model_parts:
-            m.eval()
-        total_val_loss = torch.tensor(0.0, device=self.device)
-        total_val_tokens = torch.tensor(0.0, device=self.device)
-        with torch.no_grad():
-            val_pbar = tqdm(self.val_dataloader, desc="Validation", leave=False)
-            for batch_idx, batch in enumerate(val_pbar):
-                if batch_idx >= max_steps:
-                    break
-                batch_to_device(batch, self.device)
-                current_num_tokens = (batch["labels"] != CROSS_ENTROPY_IGNORE_IDX).sum()
-                # Compute loss
-                labels = batch.pop("labels")
-                loss = self.forward_backward(batch, labels, do_backward=False)
-                val_loss = loss * current_num_tokens
-                total_val_loss += val_loss
-                total_val_tokens += current_num_tokens
-                # Update progress bar description with current average loss
-                avg_loss_so_far = (
-                    (total_val_loss / total_val_tokens).item()
-                    if total_val_tokens > 0
-                    else float("inf")
-                )
-                val_pbar.set_description(
-                    f"Running validation Loss: {avg_loss_so_far:.4f}"
-                )
-        # Aggregate validation metrics across all ranks
-        torch.distributed.all_reduce(total_val_loss)
-        torch.distributed.all_reduce(total_val_tokens)
-        avg_val_loss = (
-            (total_val_loss / total_val_tokens).item()
-            if total_val_tokens > 0
-            else float("inf")
-        )
-        for m in self.model_parts:
-            m.train()
-        print(f"\nValidation loss: {avg_val_loss}")
-
-    def cleanup(self) -> None:
-        if self.checkpointer:
-            self.checkpointer.close()
-        if self.metric_logger:
-            self.metric_logger.close()
-
-
-@parse
-def recipe_main(cfg: DictConfig) -> None:
-    # TODO: this is a hack to get the defaults from ForgeJobConfig
-    default_cfg = ForgeJobConfig()
-    # Hack to deal with literal types from titan
-    default_cfg = asdict(default_cfg)
-    cfg = OmegaConf.merge(default_cfg, cfg)
-    recipe = ForgeSFTRecipe(cfg)
-    recipe.setup()
-    recipe.train()
-    recipe.cleanup()
-
-
-if __name__ == "__main__":
-    sys.exit(recipe_main())