
Commit 1178a15

Feat: Add qwen3 and CCE for qwen family (axolotl-ai-cloud#2518)
1 parent c513487 commit 1178a15

File tree: 10 files changed (+1221, -3 lines)

examples/qwen3/qlora-fsdp.yaml

Lines changed: 68 additions & 0 deletions
@@ -0,0 +1,68 @@
base_model: Qwen/Qwen3-8B
# Automatically upload checkpoint and final model to HF
# hub_model_id: username/custom_model_name

load_in_8bit: false
load_in_4bit: true
strict: false

datasets:
  - path: tatsu-lab/alpaca
    type: alpaca
dataset_prepared_path:
val_set_size: 0.05
output_dir: ./outputs/out

sequence_len: 2048
sample_packing: true
eval_sample_packing: true
pad_to_sequence_len: true

adapter: qlora
lora_model_dir:
lora_r: 32
lora_alpha: 64
lora_dropout: 0.05
lora_target_linear: true

wandb_project:
wandb_entity:
wandb_watch:
wandb_name:
wandb_log_model:

gradient_accumulation_steps: 4
micro_batch_size: 1
num_epochs: 1
optimizer: adamw_torch_fused
lr_scheduler: cosine
learning_rate: 0.0002

bf16: auto
tf32: true

gradient_checkpointing: true
gradient_checkpointing_kwargs:
  use_reentrant: false
resume_from_checkpoint:
logging_steps: 1
flash_attention: true

warmup_steps: 10
evals_per_epoch: 4
saves_per_epoch: 1
weight_decay: 0.0
fsdp:
  - full_shard
  - auto_wrap
fsdp_config:
  fsdp_limit_all_gathers: true
  fsdp_sync_module_states: true
  fsdp_offload_params: true
  fsdp_use_orig_params: false
  fsdp_cpu_ram_efficient_loading: true
  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
  fsdp_transformer_layer_cls_to_wrap: Qwen3DecoderLayer
  fsdp_state_dict_type: FULL_STATE_DICT
  fsdp_sharding_strategy: FULL_SHARD
special_tokens:
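
Since this commit also extends Cut Cross Entropy (CCE) support to the Qwen family, the example config above can be paired with the CCE plugin. A minimal, hedged sketch of the extra keys (plugin path and flag are taken from the CCE integration docs and are assumptions here, not part of this example file):

```yaml
# Hedged sketch: extra keys to enable Cut Cross Entropy alongside the qwen3
# QLoRA FSDP example above. Plugin path and flag are assumed from the CCE
# integration README rather than shown in this diff.
plugins:
  - axolotl.integrations.cut_cross_entropy.CutCrossEntropyPlugin

cut_cross_entropy: true
```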

src/axolotl/integrations/cut_cross_entropy/README.md

Lines changed: 6 additions & 1 deletion
@@ -32,8 +32,8 @@ plugins:
 ## Supported Models

 - llama
-- llama4_text
 - llama4
+- llama4_text
 - mllama
 - phi3
 - gemma
@@ -43,6 +43,11 @@
 - mistral
 - mistral3
 - qwen2
+- qwen2_moe
+- qwen2_vl
+- qwen2_5_vl
+- qwen3
+- qwen3_moe
 - cohere
 - cohere2
 - glm
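
The entries in this list correspond to the `model_type` field of a Hugging Face model config. A quick, hedged illustration of checking whether a checkpoint is covered (the set below only contains the entries visible in this diff hunk; the full README list is longer):

```python
# Hedged illustration: CCE support is keyed on the HF config's model_type.
from transformers import AutoConfig

# Only the model types visible in this diff hunk.
VISIBLE_SUPPORTED = {
    "llama", "llama4", "llama4_text", "mllama", "phi3", "gemma",
    "mistral", "mistral3", "qwen2", "qwen2_moe", "qwen2_vl",
    "qwen2_5_vl", "qwen3", "qwen3_moe", "cohere", "cohere2", "glm",
}

config = AutoConfig.from_pretrained("Qwen/Qwen3-8B")
print(config.model_type in VISIBLE_SUPPORTED)  # True -> "qwen3" is covered
```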
src/axolotl/integrations/cut_cross_entropy/monkeypatch/llama.py

Lines changed: 174 additions & 0 deletions
@@ -0,0 +1,174 @@
"""Llama CCE patch. Adapted from transformers v4.51.2"""

# pylint: disable=duplicate-code


from types import MethodType
from typing import Optional, Union

import torch
import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
    apply_lce,
)
from transformers.cache_utils import Cache
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.llama.modeling_llama import (
    _CONFIG_FOR_DOC,
    LLAMA_INPUTS_DOCSTRING,
    KwargsForCausalLM,
)
from transformers.processing_utils import Unpack
from transformers.utils import (
    add_start_docstrings_to_model_forward,
    replace_return_docstrings,
)
from transformers.utils.deprecation import deprecate_kwarg
from transformers.utils.generic import can_return_tuple

_PATCH_OPTS: PatchOptions | None = None


@can_return_tuple
@deprecate_kwarg("num_logits_to_keep", version="4.50", new_name="logits_to_keep")
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
@replace_return_docstrings(
    output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC
)
def cce_forward(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: Union[int, torch.Tensor] = 0,
    **kwargs: Unpack[KwargsForCausalLM],
) -> CausalLMOutputWithPast:
    r"""
    labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
        Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
        config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
        (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

    logits_to_keep (`int` or `torch.Tensor`, *optional*):
        If an `int`, compute logits for the last `logits_to_keep` tokens. If `0`, calculate logits for all
        `input_ids` (special case). Only last token logits are needed for generation, and calculating them only for that
        token can save memory, which becomes pretty significant for long sequences or large vocabulary size.
        If a `torch.Tensor`, must be 1D corresponding to the indices to keep in the sequence length dimension.
        This is useful when using packed tensor format (single dimension for batch and sequence length).

    Returns:

    Example:

    ```python
    >>> from transformers import AutoTokenizer, LlamaForCausalLM

    >>> model = LlamaForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

    >>> prompt = "Hey, are you conscious? Can you talk to me?"
    >>> inputs = tokenizer(prompt, return_tensors="pt")

    >>> # Generate
    >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
    >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
    "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
    ```"""
    output_attentions = (
        output_attentions
        if output_attentions is not None
        else self.config.output_attentions
    )
    output_hidden_states = (
        output_hidden_states
        if output_hidden_states is not None
        else self.config.output_hidden_states
    )

    # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
    outputs: BaseModelOutputWithPast = self.model(
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        cache_position=cache_position,
        **kwargs,
    )

    hidden_states = outputs.last_hidden_state
    if hidden_states is None:
        raise ValueError("hidden_states is None")

    loss = None
    logits = None

    # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
    slice_indices = (
        slice(-logits_to_keep, None)
        if isinstance(logits_to_keep, int)
        else logits_to_keep
    )
    if _PATCH_OPTS is not None and _PATCH_OPTS.use_lce(labels, self.training):
        assert labels is not None
        loss = apply_lce(
            hidden_states[:, slice_indices, :],
            self.lm_head.weight,
            labels,
            _PATCH_OPTS,
            **kwargs,
        )
    else:
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        if labels is not None:
            loss = self.loss_function(
                logits=logits,
                labels=labels,
                vocab_size=self.config.vocab_size,
                **kwargs,
            )

    return CausalLMOutputWithPast(
        loss=loss,
        logits=logits,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )


def patch_llama(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    """Patch Llama for CCE."""
    global _PATCH_OPTS  # pylint: disable=global-statement
    from transformers.models.llama import modeling_llama

    _PATCH_OPTS = patch_options

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_llama.LlamaForCausalLM
        ), f"Expected a LlamaForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_llama.LlamaForCausalLM.forward = cce_forward
    return None
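
For context, the `apply_lce` branch above replaces the standard pattern of materializing full-vocabulary logits before taking cross-entropy. A rough sketch of that naive equivalent (illustration only, not axolotl code):

```python
# Illustration only: the memory-hungry computation that the apply_lce branch
# above avoids. Full (batch, seq, vocab) logits are built, labels are shifted
# by one position, and -100 labels are ignored, matching the docstring above.
import torch
import torch.nn.functional as F


def naive_causal_lm_loss(
    hidden_states: torch.Tensor,   # (batch, seq, hidden)
    lm_head_weight: torch.Tensor,  # (vocab, hidden)
    labels: torch.LongTensor,      # (batch, seq)
) -> torch.Tensor:
    logits = hidden_states @ lm_head_weight.T        # the large tensor CCE never materializes
    shift_logits = logits[:, :-1, :].contiguous()    # position t predicts token t+1
    shift_labels = labels[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)).float(),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```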

src/axolotl/integrations/cut_cross_entropy/monkeypatch/patch.py

Lines changed: 24 additions & 2 deletions
@@ -5,9 +5,7 @@
 import transformers
 from cut_cross_entropy.cce_utils import LinearCrossEntropyImpl
 from cut_cross_entropy.linear_cross_entropy import LCE_IMPL_DEFAULT
-from cut_cross_entropy.transformers.llama import patch_llama
 from cut_cross_entropy.transformers.phi3 import patch_phi3
-from cut_cross_entropy.transformers.qwen2 import patch_qwen2
 from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT

 from axolotl.integrations.cut_cross_entropy.monkeypatch.cohere import (
@@ -24,6 +22,9 @@
     patch_glm,
     patch_glm4,
 )
+from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
+    patch_llama,
+)
 from axolotl.integrations.cut_cross_entropy.monkeypatch.llama4 import (
     patch_llama4,
     patch_llama4_text,
@@ -33,6 +34,22 @@
     patch_mistral3,
 )
 from axolotl.integrations.cut_cross_entropy.monkeypatch.mllama import patch_mllama
+from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2 import (
+    patch_qwen2,
+)
+from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_5_vl import (
+    patch_qwen2_5_vl,
+)
+from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_moe import (
+    patch_qwen2_moe,
+)
+from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen2_vl import (
+    patch_qwen2_vl,
+)
+from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3 import patch_qwen3
+from axolotl.integrations.cut_cross_entropy.monkeypatch.qwen3_moe import (
+    patch_qwen3_moe,
+)

 CUT_CROSS_ENTROPY_MODEL_MAPPING = {
     "llama": patch_llama,
@@ -47,6 +64,11 @@
     "mistral": patch_mistral,
     "mistral3": patch_mistral3,
     "qwen2": patch_qwen2,
+    "qwen2_moe": patch_qwen2_moe,
+    "qwen2_vl": patch_qwen2_vl,
+    "qwen2_5_vl": patch_qwen2_5_vl,
+    "qwen3": patch_qwen3,
+    "qwen3_moe": patch_qwen3_moe,
     "cohere": patch_cohere,
     "cohere2": patch_cohere2,
     "glm": patch_glm,
src/axolotl/integrations/cut_cross_entropy/monkeypatch/qwen2.py

Lines changed: 37 additions & 0 deletions
@@ -0,0 +1,37 @@
"""Qwen2 CCE patch. The model inherits Llama's modeling code and uses the same forward method."""

# pylint: disable=duplicate-code

from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import (
    PatchOptions,
    TransformersModelT,
)


def patch_qwen2(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    from transformers.models.qwen2 import modeling_qwen2

    # Set the _PATCH_OPTS in the llama patch file
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import (
        cce_forward,
    )

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen2.Qwen2ForCausalLM
        ), f"Expected a Qwen2ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_qwen2.Qwen2ForCausalLM.forward = cce_forward
    return None
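
The new `patch_qwen3` imported in `patch.py` presumably follows the same pattern. A hedged sketch of what the qwen3 monkeypatch likely looks like (the actual file is not included in this excerpt, so treat this as an assumption mirroring the Qwen2 file above):

```python
# Hedged sketch of the analogous Qwen3 patch, mirroring the Qwen2 file above.
# The real monkeypatch/qwen3.py is not shown in this excerpt.
from types import MethodType

import transformers
from cut_cross_entropy.transformers.utils import PatchOptions, TransformersModelT


def patch_qwen3(
    maybe_model: TransformersModelT | str | transformers.PretrainedConfig,
    patch_options: PatchOptions,
) -> TransformersModelT | None:
    from transformers.models.qwen3 import modeling_qwen3

    # Reuse the Llama cce_forward; Qwen3's causal LM forward matches it.
    import axolotl.integrations.cut_cross_entropy.monkeypatch.llama as llama_patch

    llama_patch._PATCH_OPTS = patch_options  # pylint: disable=protected-access

    from axolotl.integrations.cut_cross_entropy.monkeypatch.llama import cce_forward

    if isinstance(maybe_model, transformers.PreTrainedModel):
        assert isinstance(
            maybe_model, modeling_qwen3.Qwen3ForCausalLM
        ), f"Expected a Qwen3ForCausalLM model. Got {type(maybe_model)}."
        maybe_model.forward = MethodType(cce_forward, maybe_model)
        return maybe_model

    modeling_qwen3.Qwen3ForCausalLM.forward = cce_forward
    return None
```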
