94 changes: 94 additions & 0 deletions examples/llm_pte_finetuning/llama3_config.yaml
@@ -0,0 +1,94 @@
output_dir: /tmp/llama-3.2-1B_ft-output # /tmp may be deleted by your system. Change it to your preference.

# Model Arguments
model:
  _component_: torchtune.models.llama3_2.lora_llama3_2_1b
  lora_attn_modules: ['q_proj', 'v_proj', 'output_proj']
  apply_lora_to_mlp: True
  lora_rank: 64 # higher increases accuracy and memory
  lora_alpha: 128 # usually alpha=2*rank
  lora_dropout: 0.0

# Tokenizer
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: /tmp/Llama-3.2-1B-Instruct/original/tokenizer.model
  special_tokens_path: null
  max_seq_len: 512
  prompt_template: null

# Dataset
dataset:
  _component_: torchtune.datasets.alpaca_dataset
  packed: False # True increases speed
seed: null
shuffle: True

checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: /tmp/Llama-3.2-1B-Instruct/
  checkpoint_files: [
    model.safetensors
  ]
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3_2
resume_from_checkpoint: False

# Fine-tuning arguments
batch_size: 1
epochs: 1
optimizer:
  _component_: torch.optim.AdamW
  fused: True
  weight_decay: 0.01
  lr: 3e-4

loss:
  _component_: torchtune.modules.loss.CEWithChunkedOutputLoss
max_steps_per_epoch: null
gradient_accumulation_steps: 1 # Use to increase effective batch size
optimizer_in_bwd: True # True saves memory. Requires gradient_accumulation_steps=1
clip_grad_norm: null
compile: False # torch.compile the model + loss, True increases speed + decreases memory

# Training environment
device: cpu
dtype: fp32

# Memory management
enable_activation_checkpointing: False # True reduces memory
enable_activation_offloading: False # True reduces memory

# Logging
metric_logger:
  _component_: torchtune.training.metric_logging.DiskLogger
  log_dir: ${output_dir}/logs
log_every_n_steps: 1
log_peak_memory_stats: True


# Profiler (disabled)
profiler:
  _component_: torchtune.training.setup_torch_profiler
  enabled: False

  # Output directory of trace artifacts
  output_dir: ${output_dir}/profiling_outputs

  # `torch.profiler.ProfilerActivity` types to trace
  cpu: True
  cuda: True

  # Trace options passed to `torch.profiler.profile`
  profile_memory: False
  with_stack: False
  record_shapes: True
  with_flops: False

  # `torch.profiler.schedule` options:
  # wait_steps -> wait, warmup_steps -> warmup, active_steps -> active, num_cycles -> repeat
  wait_steps: 5
  warmup_steps: 3
  active_steps: 2
  num_cycles: 1
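
For orientation, here is a minimal sketch of how a config like this is consumed downstream (see model_exporter.py below). It assumes torchtune's `config.instantiate` utility and OmegaConf for loading; the file path is illustrative and not part of this PR.

# Minimal sketch: load the YAML above and instantiate the components it names.
# Assumes torchtune's config utilities; the path below is illustrative.
from omegaconf import OmegaConf
from torchtune import config

cfg = OmegaConf.load("examples/llm_pte_finetuning/llama3_config.yaml")
model = config.instantiate(cfg.model)          # lora_llama3_2_1b with the LoRA settings above
tokenizer = config.instantiate(cfg.tokenizer)  # llama3_tokenizer, max_seq_len=512
loss_fn = config.instantiate(cfg.loss)         # CEWithChunkedOutputLoss
dataset = config.instantiate(cfg.dataset, tokenizer)  # alpaca_dataset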
2 changes: 1 addition & 1 deletion examples/llm_pte_finetuning/model_exporter.py
@@ -48,7 +48,7 @@ def main() -> None:
     loss_fn = config.instantiate(cfg.loss)
 
     ds = config.instantiate(cfg.dataset, tokenizer)
-    train_set, val_set = torch.utils.data.random_split(ds, [0.8, 0.2])
+    train_set, _ = torch.utils.data.random_split(ds, [0.8, 0.2])
     train_dataloader = get_dataloader(cfg, train_set, tokenizer, loss_fn)
 
     max_seq_len = cfg.tokenizer.max_seq_len
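
A quick note on the split used here: recent PyTorch versions accept fractions summing to 1.0 in `random_split`. A minimal sketch with a toy dataset (the seed is illustrative):

# Minimal sketch of fractional random_split (toy dataset, illustrative seed).
import torch
from torch.utils.data import TensorDataset, random_split

ds = TensorDataset(torch.arange(10))
train_set, _ = random_split(ds, [0.8, 0.2], generator=torch.Generator().manual_seed(0))
print(len(train_set))  # 8 of the 10 examples land in the training split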
19 changes: 16 additions & 3 deletions examples/llm_pte_finetuning/training_lib.py
@@ -36,13 +36,26 @@ def __init__(
         super().__init__()
         self.model = model
         self.loss = loss
+        if loss.__class__.__name__ == "CEWithChunkedOutputLoss":
+            # set num_output_chunks for model
+            # pyre-ignore
+            model.set_num_output_chunks(self.loss.num_output_chunks)
+
+        # (batch_size, 1) tensor of ignore_index
+        # pyre-ignore
+        self.ignore_labels_cache = torch.full(
+            (1, 1), self.loss.ignore_index, device="cpu"  # pyre-ignore
+        )
 
     def forward(self, input: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
         # Output is of the shape (seq_len, vocab_size).
         logits = self.model(input)
-        logits = logits[..., :-1, :].contiguous()
-        labels = labels[..., 1:].contiguous()
-        logits = logits.transpose(1, 2)
+        labels = torch.hstack(
+            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
+        )
+        if not isinstance(logits, list):
+            labels = labels.reshape(-1)
+            logits = logits.reshape(-1, logits.size(-1))
         return self.loss(logits, labels)
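
To make the new label handling concrete, a small sketch of the shift-and-pad step, assuming the usual cross-entropy ignore_index of -100 (toy tensor values):

# Minimal sketch of the shift-and-pad label trick in forward() above.
# -100 mirrors the loss's ignore_index; the label values are toy data.
import torch

labels = torch.tensor([[5, 6, 7, 8]])    # (batch=1, seq_len=4)
ignore_cache = torch.full((1, 1), -100)  # plays the role of self.ignore_labels_cache
shifted = torch.hstack((labels[..., 1:], ignore_cache[: labels.shape[0]]))
print(shifted)  # tensor([[   6,    7,    8, -100]]): position t is scored against token t+1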
Contributor:
should we return the prediction too?


