
Commit bf9efe2

[llama4] fix the mm yaml, add scout single gpu yaml (axolotl-ai-cloud#2510)
* [llama4] fix the mm yaml, add scout single gpu yaml
* add README for llama4
* rename to specify fsdp
1 parent 0dac2dd commit bf9efe2

File tree: 3 files changed, +133 −23 lines

examples/llama-4/README.md

Lines changed: 10 additions & 0 deletions
# Llama 4 by Meta AI

## Available Examples

### Llama 4 Scout 17Bx16Experts (109B)

- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)

Our Single GPU implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second.
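The 68.5 GB figure relies on several memory savers working together rather than any single trick: a pre-quantized, linearized base model, QLoRA adapters, a 4-bit optimizer, and offloaded activation checkpointing. As a quick orientation, here are the memory-relevant settings excerpted from the single-GPU config below (an excerpt only, not a standalone config):

```yaml
# Memory-relevant settings from scout-qlora-single-h100.yaml (excerpt only)
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16  # NF4 pre-quantized weights
llama4_linearized_experts: true   # Axolotl's linearized-experts model variant
load_in_4bit: true
adapter: qlora                    # train LoRA adapters on top of the 4-bit base
optimizer: adamw_torch_4bit       # 4-bit optimizer states
gradient_checkpointing: offload   # offload checkpointed activations
sequence_len: 4096                # the config notes up to 8k fits on a single H100
```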
examples/llama-4/scout-qlora-single-h100.yaml

Lines changed: 86 additions & 0 deletions
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

strict: false

plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true

llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  - shared_expert.gate_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
#  - experts.gate_projs.[0-9]+$
#  - experts.up_projs.[0-9]+$
#  - experts.down_projs.[0-9]+$
lora_modules_to_save:
#  - lm_head
#  - embed_tokens

lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096 # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4

bf16: true
tf32: true

logging_steps: 1
flash_attention: true

gradient_checkpointing: offload
gradient_checkpointing_kwargs:
  use_reentrant: false

warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
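To reuse this single-GPU recipe on your own conversation data, only the datasets block needs to change. A minimal sketch, assuming a local ShareGPT-style JSONL file whose records contain a conversations list with from/value keys; the path and file below are placeholders, not part of this commit:

```yaml
datasets:
  - path: ./data/my_conversations.jsonl  # hypothetical local file (placeholder)
    ds_type: json                        # load from a local JSON/JSONL file
    type: chat_template                  # render with the chat_template set above (llama4)
    field_messages: conversations        # key that holds the message list
    message_property_mappings:
      role: from                         # map "from" -> role
      content: value                     # map "value" -> content
```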

examples/llama-4/scout-lora.yaml renamed to examples/llama-4/scout-vision-qlora-fsdp.yaml

Lines changed: 37 additions & 23 deletions
@@ -1,74 +1,88 @@
-base_model: meta-llama/Llama-4-Scout-17B-16E
+base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
 model_type: Llama4ForConditionalGeneration
+processor_type: Llama4Processor
 # Automatically upload checkpoint and final model to HF
 # hub_model_id: username/custom_model_name
 
 strict: false
 
-# torch_compile: true
+# these 3 lines are needed for now to handle vision chat templates w images
+skip_prepare_dataset: true
+remove_unused_columns: false
+sample_packing: false
 
-adapter: lora
+sequence_len: 4096
+
+plugins:
+  - axolotl.integrations.liger.LigerPlugin
+
+liger_glu_activation: true
+liger_rms_norm: true
+liger_layer_norm: true
+
+llama4_linearized_experts: true # use Axolotl's customized model
+load_in_4bit: true
+adapter: qlora
 lora_r: 32
 lora_alpha: 64
 lora_target_modules:
   - self_attn.q_proj
   - self_attn.k_proj
   - self_attn.v_proj
   - self_attn.o_proj
+  - shared_expert.gate_proj
+  - shared_expert.up_proj
+  - shared_expert.down_proj
+  - vision_adapter.mlp.fc1
+  - vision_adapter.mlp.fc2
+#  - experts.gate_projs.[0-9]+$
+#  - experts.up_projs.[0-9]+$
+#  - experts.down_projs.[0-9]+$
 lora_modules_to_save:
   - lm_head
   - embed_tokens
 
 chat_template: llama4
 datasets:
-  - path: mlabonne/FineTome-100k
+  - path: HuggingFaceH4/llava-instruct-mix-vsft
     type: chat_template
-    split: train[:20%]
-    field_messages: conversations
-    message_property_mappings:
-      role: from
-      content: value
+    split: train[:1%]
+    field_messages: messages
 
 dataset_prepared_path: last_run_prepared
 val_set_size: 0.0
 output_dir: ./outputs/out
 
-sequence_len: 4096
-sample_packing: true
-pad_to_sequence_len: true
-
 gradient_accumulation_steps: 1
 micro_batch_size: 1
 num_epochs: 1
-optimizer: adamw_torch_8bit
+optimizer: adamw_torch_4bit
 lr_scheduler: cosine
 learning_rate: 2e-5
 
 bf16: true
 tf32: true
 
-# gradient_checkpointing: true
-# gradient_checkpointing_kwargs:
-#   use_reentrant: false
 logging_steps: 1
 flash_attention: true
 
 warmup_steps: 100
-evals_per_epoch: 2
+evals_per_epoch: 1
 saves_per_epoch: 1
 weight_decay: 0.0
 fsdp:
   - auto_wrap
   - full_shard
 fsdp_config:
-  fsdp_version: 2
-  fsdp_offload_params: false
+  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
+  fsdp_limit_all_gathers: true
+  fsdp_sync_module_states: true
+  fsdp_offload_params: true
+  fsdp_use_orig_params: false
   fsdp_cpu_ram_efficient_loading: true
   fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
-  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
-  fsdp_state_dict_type: SHARDED_STATE_DICT
+  fsdp_state_dict_type: FULL_STATE_DICT
   fsdp_sharding_strategy: FULL_SHARD
-  fsdp_reshard_after_forward: true
   fsdp_activation_checkpointing: true
 special_tokens:
   pad_token: <|finetune_right_pad_id|>
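To point this vision example at a different image-chat dataset, keep the three workaround flags from the top of the config (skip_prepare_dataset, remove_unused_columns, sample_packing: false) as they are; only the datasets entry changes. A minimal sketch with a placeholder dataset id that is not part of this commit:

```yaml
# keep the vision chat-template workarounds from this example unchanged
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

chat_template: llama4
datasets:
  - path: your-org/your-image-chat-dataset  # hypothetical HF dataset (placeholder)
    type: chat_template
    split: train[:5%]                       # start with a small slice to validate the pipeline
    field_messages: messages                # OpenAI-style messages with interleaved image content
```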
