
Commit 682a9cf

NanoCode012, winglian, and SalmanMohammadi authored
Fix: add delinearization and make qlora work with fsdp2 (axolotl-ai-cloud#2515)
* fixes for delinearization, and make qlora work with fsdp2 * Add back mistakenly removed lm_eval * typo [skip ci] * patch evals for torch.compile + fsdp2 * also check torch_compile w fsdp2 * lots of fixes for flex attn with llama4 * fix patch check and patch llama4 too * attempt to make the patches stick * use transformers 4.51.2 * update configs and README for llama4 * remove torch.compile for CI test * cleanup any existing singletons * set singleton cache to None instead of deleting * use importlib reload with monkeypatch * don't worry about transformers version, mark inputs with grads, fix regex * make sure embeds aren't on cpu * logging and mem improvements * vllm version and add to docker, make sure to save processor on conversion * fix ambiguous tensor bool check * fix vllm to not use v1, upgrade hf transformers * fix tests * make flex_attn_compile_kwargs configurable, since this depends on model params --------- Co-authored-by: Wing Lian <[email protected]> Co-authored-by: Salman Mohammadi <[email protected]>
1 parent 271b24c commit 682a9cf

26 files changed: +629 −45 lines changed
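The headline change is QLoRA training under FSDP2. As a minimal sketch, condensed from the example configs added further down in this commit (values are illustrative, not prescriptive), the combination comes down to pairing a 4-bit base load and a qlora adapter with fsdp_version: 2:

# Hedged sketch: QLoRA + FSDP2, condensed from the Llama 4 Scout examples below.
load_in_4bit: true                  # 4-bit quantized base model
adapter: qlora                      # train LoRA adapters on top of the quantized base
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_version: 2                   # opt into FSDP2 instead of FSDP1
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true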

.github/workflows/main.yml

Lines changed: 1 addition & 1 deletion
@@ -29,7 +29,7 @@ jobs:
 cuda_version: 12.4.1
 python_version: "3.11"
 pytorch: 2.6.0
-axolotl_extras:
+axolotl_extras: vllm
 is_latest: true
 runs-on: axolotl-gpu-runner
 steps:

examples/llama-4/README.md

Lines changed: 19 additions & 7 deletions
@@ -1,16 +1,28 @@
 # Llama 4 by Meta AI
 
+## Flash Attention vs Flex Attention
+
+While Flash Attention support is "enabled" for Llama-4, the upstream implementation is not correct and usage of Flex Attention is recommended.
+
 ## Available Examples
 
 ### Llama 4 Scout 17Bx16Experts (109B)
-- [Multi-Modal/Vision QLoRA w/ FSDP1](./scout-vision-qlora-fsdp.yaml)
-- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100.yaml)
-- [Text Multi GPU QLoRA w/ FSDP1](./scout-qlora-fsdp1.yaml)
 
-Our Single H100 implementation for Llama 4 Scout uses only 68.5GB VRAM for post-training with 4k context length @ 546 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-sft/runs/zic56rhd)
+Flex Attention
+- [Text Single GPU (H100) QLoRA](./scout-qlora-single-h100-flex.yaml)
+- [Text Multi GPU QLoRA w/ FSDP2](./scout-qlora-flexattn-fsdp2.yaml)
 
-### Llama 4 Maverick 17Bx128Experts (400B)
+[//]: # (Flash Attention &#40;Do not use&#41;)
+
+[//]: # (- [Multi-Modal/Vision QLoRA w/ FSDP1]&#40;./scout-vision-qlora-fsdp.yaml&#41;)
 
-- [Text Multi GPU QLoRA w/FSDP1](./maverick-qlora-fsdp1.yaml)
+[//]: # (- [Text Single GPU &#40;H100&#41; QLoRA]&#40;./scout-qlora-single-h100.yaml&#41;)
+
+[//]: # (- [Text Multi GPU QLoRA w/ FSDP1]&#40;./scout-qlora-fsdp1.yaml&#41;)
+
+Our Single H100 implementation for Llama 4 Scout uses only 64.5GB VRAM for post-training with 4k context length @ 519 tokens/second. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/wpie7dkj)
+Multi-GPU (4xH100) for Llama 4 Scout uses 62.8GB VRAM/GPU at 4k context length @ 280 tokens/second per GPU. [WandB logs here](https://wandb.ai/axolotl-ai/llama4-flexattn-qlora/runs/2lkezdj8)
+
+### Llama 4 Maverick 17Bx128Experts (400B)
 
-Our 4xH100 implementation for Llama 4 Maverick uses 79.5GB VRAM/GPU for post-training with 4k context length @ 206 tokens/second. [WandB logs here.](https://wandb.ai/axolotl-ai/llama-sft/runs/siyvwuxc?nw=nwuserwinglian)
+Coming Soon
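Since the README above recommends Flex Attention over the upstream Flash Attention path, the new example configs below all enable it through two keys; flex_attn_compile_kwargs is now user-configurable because the best compile settings depend on the model's parameters. A minimal sketch, with values copied from the examples that follow:

# Hedged sketch: enabling Flex Attention with explicit torch.compile kwargs.
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false                     # assume static shapes, matching the packed 4k sequences
  mode: max-autotune-no-cudagraphs   # compile mode used in the Llama 4 examples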
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true

llama4_linearized_experts: true
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  - shared_expert.gate_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
  # - experts.gate_projs.[0-9]+$
  # - experts.up_projs.[0-9]+$
  # - experts.down_projs.[0-9]+$
lora_modules_to_save:
  # - lm_head
  # - embed_tokens

chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096
sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 2
num_epochs: 3
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4

bf16: true
tf32: true

logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs

warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

plugins:
  - axolotl.integrations.liger.LigerPlugin
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true
cut_cross_entropy: true

llama4_linearized_experts: true  # needed with custom linearized experts model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  - shared_expert.gate_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
  # - experts.gate_projs.[0-9]+$  # optionally train the moe experts
  # - experts.up_projs.[0-9]+$
  # - experts.down_projs.[0-9]+$
lora_modules_to_save:
  # - lm_head  # needed if modifying vocabulary
  # - embed_tokens

lora_mlp_kernel: true
lora_qkv_kernel: true
lora_o_kernel: true

chat_template: llama4
datasets:
  - path: mlabonne/FineTome-100k
    type: chat_template
    split: train[:20%]
    field_messages: conversations
    message_property_mappings:
      role: from
      content: value

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

sequence_len: 4096  # up to 8k will work on a single H100
sample_packing: true
pad_to_sequence_len: true

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4

bf16: true
tf32: true

torch_compile: true
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs

gradient_checkpointing: offload
gradient_checkpointing_kwargs:
  use_reentrant: false

logging_steps: 1
warmup_steps: 20
evals_per_epoch: 1
saves_per_epoch: 1

weight_decay: 0.0
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>
Lines changed: 89 additions & 0 deletions
@@ -0,0 +1,89 @@
base_model: axolotl-quants/Llama-4-Scout-17B-16E-Linearized-bnb-nf4-bf16
model_type: Llama4ForConditionalGeneration
processor_type: Llama4Processor
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

# these 3 lines are needed for now to handle vision chat templates w images
skip_prepare_dataset: true
remove_unused_columns: false
sample_packing: false

sequence_len: 4096

plugins:
  - axolotl.integrations.liger.LigerPlugin

liger_glu_activation: true
liger_rms_norm: true
liger_layer_norm: true

llama4_linearized_experts: true  # use Axolotl's customized model
load_in_4bit: true
adapter: qlora
lora_r: 32
lora_alpha: 64
lora_target_modules:
  - self_attn.q_proj
  - self_attn.k_proj
  - self_attn.v_proj
  - self_attn.o_proj
  - shared_expert.gate_proj
  - shared_expert.up_proj
  - shared_expert.down_proj
  - vision_adapter.mlp.fc1
  - vision_adapter.mlp.fc2
  # - experts.gate_projs.[0-9]+$
  # - experts.up_projs.[0-9]+$
  # - experts.down_projs.[0-9]+$
lora_modules_to_save:
  - lm_head
  - embed_tokens

chat_template: llama4
datasets:
  - path: HuggingFaceH4/llava-instruct-mix-vsft
    type: chat_template
    split: train[:1%]
    field_messages: messages

dataset_prepared_path: last_run_prepared
val_set_size: 0.0
output_dir: ./outputs/out

gradient_accumulation_steps: 1
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_4bit
lr_scheduler: cosine
learning_rate: 1e-4

bf16: true
tf32: true

logging_steps: 1
flex_attention: true
flex_attn_compile_kwargs:
  dynamic: false
  mode: max-autotune-no-cudagraphs

warmup_steps: 10
evals_per_epoch: 1
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - auto_wrap
  - full_shard
fsdp_config:
  fsdp_version: 2
  fsdp_offload_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Llama4TextDecoderLayer
  fsdp_state_dict_type: SHARDED_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
  fsdp_reshard_after_forward: true
  fsdp_activation_checkpointing: true
special_tokens:
  pad_token: <|finetune_right_pad_id|>
  eos_token: <|eot|>

requirements.txt

Lines changed: 1 addition & 1 deletion
@@ -12,7 +12,7 @@ liger-kernel==0.5.6
 packaging==23.2
 
 peft==0.15.1
-transformers==4.51.1
+transformers==4.51.3
 tokenizers>=0.21.1
 accelerate==1.6.0
 datasets==3.5.0
