diff --git a/examples/llm_pte_finetuning/llama3_config.yaml b/examples/llm_pte_finetuning/llama3_config.yaml
new file mode 100644
index 00000000000..e4428ff6e01
--- /dev/null
+++ b/examples/llm_pte_finetuning/llama3_config.yaml
@@ -0,0 +1,94 @@
+output_dir: /tmp/llama-3.2-1B_ft-output # /tmp may be deleted by your system. Change it to your preference.
+
+# Model Arguments
+model:
+  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
+  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
+  apply_lora_to_mlp: True
+  lora_rank: 64 # higher increases accuracy and memory
+  lora_alpha: 128 # usually alpha=2*rank
+  lora_dropout: 0.0
+
+# Tokenizer
+tokenizer:
+  _component_: torchtune.models.llama3.llama3_tokenizer
+  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
+  special_tokens_path: null
+  max_seq_len: 512
+  prompt_template: null
+
+# Dataset
+dataset:
+  _component_: torchtune.datasets.alpaca_dataset
+  packed: False # True increases speed
+seed: null
+shuffle: True
+
+checkpointer:
+  _component_: torchtune.training.FullModelHFCheckpointer
+  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
+  checkpoint_files: [
+    model.safetensors
+  ]
+  recipe_checkpoint: null
+  output_dir: ${output_dir}
+  model_type: LLAMA3_2
+resume_from_checkpoint: False
+
+# Fine-tuning arguments
+batch_size: 1
+epochs: 1
+optimizer:
+  _component_: torch.optim.AdamW
+  fused: True
+  weight_decay: 0.01
+  lr: 3e-4
+
+loss:
+  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
+max_steps_per_epoch: null
+gradient_accumulation_steps: 1 # Use to increase effective batch size
+optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
+clip_grad_norm: null
+compile: False # torch.compile the model + loss, True increases speed + decreases memory
+
+# Training environment
+device: cpu
+dtype: fp32
+
+# Memory management
+enable_activation_checkpointing: False # True reduces memory
+enable_activation_offloading: False # True reduces memory
+
+# Logging
+metric_logger:
+  _component_: torchtune.training.metric_logging.DiskLogger
+  log_dir: ${output_dir}/logs
+log_every_n_steps: 1
+log_peak_memory_stats: True
+
+
+# Profiler (disabled)
+profiler:
+  _component_: torchtune.training.setup_torch_profiler
+  enabled: False
+
+  #Output directory of trace artifacts
+  output_dir: ${output_dir}/profiling_outputs
+
+  #`torch.profiler.ProfilerActivity` types to trace
+  cpu: True
+  cuda: True
+
+  #trace options passed to `torch.profiler.profile`
+  profile_memory: False
+  with_stack: False
+  record_shapes: True
+  with_flops: False
+
+  # `torch.profiler.schedule` options:
+  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
+  wait_steps: 5
+  warmup_steps: 3
+  active_steps: 2
+  num_cycles: 1
diff --git a/examples/llm_pte_finetuning/model_exporter.py b/examples/llm_pte_finetuning/model_exporter.py
index e7f074c8769..d510dbe199c 100644
--- a/examples/llm_pte_finetuning/model_exporter.py
+++ b/examples/llm_pte_finetuning/model_exporter.py
@@ -48,7 +48,7 @@ def main() -> None:
     loss_fn = config.instantiate(cfg.loss)
 
     ds = config.instantiate(cfg.dataset, tokenizer)
-    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
+    train_set, _ = torch.utils.data.random_split(ds, [0.8, 0.2])
     train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
 
     max_seq_len = cfg.tokenizer.max_seq_len
diff --git a/examples/llm_pte_finetuning/training_lib.py b/examples/llm_pte_finetuning/training_lib.py
index f8cae70d39c..edd5e04d291 100644
--- a/examples/llm_pte_finetuning/training_lib.py
+++ b/examples/llm_pte_finetuning/training_lib.py
@@ -36,13 +36,26 @@ def __init__(
         super().__init__()
         self.model = model
         self.loss = loss
+        if loss.__class__.__name__ == "CEWithChunkedOutputLoss":
+            # set num_output_chunks for model
+            # pyre-ignore
+            model.set_num_output_chunks(self.loss.num_output_chunks)
+
+        # (batch_size, 1) tensor of ignore_index
+        # pyre-ignore
+        self.ignore_labels_cache = torch.full(
+            (1, 1), self.loss.ignore_index, device="cpu"  # pyre-ignore
+        )
 
     def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Output is of the shape (seq_len, vocab_size).
         logits = self.model(input)
-        logits = logits[..., :-1, :].contiguous()
-        labels = labels[..., 1:].contiguous()
-        logits = logits.transpose(1, 2)
+        labels = torch.hstack(
+            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+        )
+        if not isinstance(logits, list):
+            labels = labels.reshape(-1)
+            logits = logits.reshape(-1, logits.size(-1))
         return self.loss(logits, labels)
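
The forward() change above keeps the sequence length fixed: instead of dropping the last logit and the first label, the labels are shifted left by one position and the freed final slot is filled with ignore_index, which cross-entropy then skips. The short sketch below is illustration only, not part of the diff; the shapes and the ignore_index value of -100 (the torch.nn.functional.cross_entropy default) are assumptions made for the demo. It checks that, for a plain cross-entropy, the two formulations give the same loss.

    import torch
    import torch.nn.functional as F

    batch_size, seq_len, vocab_size = 2, 8, 32
    ignore_index = -100  # assumed; matches the F.cross_entropy default

    logits = torch.randn(batch_size, seq_len, vocab_size)
    labels = torch.randint(0, vocab_size, (batch_size, seq_len))

    # Old formulation: slice off the last logit and the first label, then put
    # the class dimension second for cross_entropy.
    old_loss = F.cross_entropy(
        logits[..., :-1, :].contiguous().transpose(1, 2),
        labels[..., 1:].contiguous(),
        ignore_index=ignore_index,
    )

    # New formulation: shift labels left and pad the final position with
    # ignore_index so it is excluded from the loss; seq_len is unchanged.
    ignore_labels_cache = torch.full((batch_size, 1), ignore_index)
    shifted = torch.hstack((labels[..., 1:], ignore_labels_cache[: labels.shape[0]]))
    new_loss = F.cross_entropy(
        logits.reshape(-1, vocab_size),
        shifted.reshape(-1),
        ignore_index=ignore_index,
    )

    assert torch.allclose(old_loss, new_loss)

When the configured loss is torchtune's CEWithChunkedOutputLoss, the model is expected to return a list of logit chunks rather than a single tensor, which is why the reshape in forward() is guarded by the isinstance(logits, list) check and the chunk count is registered on the model in __init__.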
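
For reference, the _component_ entries in llama3_config.yaml are dotted import paths that torchtune resolves at runtime, with the sibling keys passed as keyword arguments; model_exporter.py uses this pattern via config.instantiate. A minimal sketch of that pattern, assuming torchtune and omegaconf are installed and that the tokenizer path in the config exists on disk:

    from omegaconf import OmegaConf
    from torchtune import config

    cfg = OmegaConf.load("examples/llm_pte_finetuning/llama3_config.yaml")

    # Each _component_ names a callable; the remaining keys become its kwargs.
    model = config.instantiate(cfg.model)          # lora_llama3_2_1b(...)
    loss_fn = config.instantiate(cfg.loss)         # CEWithChunkedOutputLoss()
    tokenizer = config.instantiate(cfg.tokenizer)  # needs tokenizer.model on disk
    dataset = config.instantiate(cfg.dataset, tokenizer)  # as in model_exporter.py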