Commit 07e4f2e

support for qwen3 with lora kernels (axolotl-ai-cloud#2588)

* support for qwen3 with lora kernels
* fix patch
* typo

1 parent c7d07de · commit 07e4f2e

File tree

src/axolotl/monkeypatch/lora_kernels.py
tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py

2 files changed: +56 -19 lines changed
src/axolotl/monkeypatch/lora_kernels.py

Lines changed: 41 additions & 12 deletions

```diff
@@ -23,22 +23,42 @@
 
 LOG = get_logger(__name__)
 
-ORIGINAL_QKV_CODE = """
+QKV_PATCHES = [
+    (
+        """
     query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
     key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
     value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 """.lstrip(
-    "\n"
-)
-
-PATCHED_QKV_CODE = """
+            "\n"
+        ),
+        """
     query_states, key_states, value_states = self.apply_qkv(hidden_states)
     query_states = query_states.view(hidden_shape).transpose(1, 2)
     key_states = key_states.view(hidden_shape).transpose(1, 2)
     value_states = value_states.view(hidden_shape).transpose(1, 2)
 """.lstrip(
-    "\n"
-)
+            "\n"
+        ),
+    ),
+    (
+        """
+    query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
+    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+""".lstrip(
+            "\n"
+        ),
+        """
+    query_states, key_states, value_states = self.apply_qkv(hidden_states)
+    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
+    key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
+    value_states = value_states.view(hidden_shape).transpose(1, 2)
+""".lstrip(
+            "\n"
+        ),
+    ),
+]
 
 ORIGINAL_O_CODE = """
     attn_output = self.o_proj(attn_output)
```
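
The second `(original, patched)` pair exists because Qwen3-family attention wraps the query/key projections in `q_norm`/`k_norm`, so the Llama-style pattern never matches its forward source. A minimal standalone sketch of that difference (not part of the patch; it only inspects the upstream `transformers` classes already referenced in this commit):

```python
import inspect

from transformers.models.llama.modeling_llama import LlamaAttention
from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention

llama_src = inspect.getsource(LlamaAttention.forward)
qwen3_src = inspect.getsource(Qwen3MoeAttention.forward)

# Llama projects q/k/v directly, so the first QKV_PATCHES entry matches.
print("q_norm" in llama_src)  # expected: False
# Qwen3-MoE applies q_norm/k_norm around the projections, so only the
# second (Qwen3-style) entry can match.
print("q_norm" in qwen3_src)  # expected: True
```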

```diff
@@ -128,10 +148,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
     try:
         # Dynamically import the module and attention class
         module_path = f"transformers.models.{model_type}.modeling_{model_type}"
-        module = __import__(
-            module_path, fromlist=[f"{model_type.capitalize()}Attention"]
+        model_cls_prefix = "".join(
+            [part.capitalize() for part in model_type.split("_")]
         )
-        attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
+        module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
+        attention_cls = getattr(module, f"{model_cls_prefix}Attention")
 
         return attention_cls
     except (ImportError, AttributeError) as e:
```
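
The class-name lookup above previously relied on `str.capitalize()`, which lowercases everything after the first character and keeps underscores, so multi-word model types such as `qwen3_moe` resolved to a nonexistent `Qwen3_moeAttention`. A small sketch of the old versus new derivation, using only the model types that appear in this commit:

```python
def model_cls_prefix(model_type: str) -> str:
    # Mirrors the patched logic: capitalize each underscore-separated part.
    return "".join(part.capitalize() for part in model_type.split("_"))

# Old logic: fine for single-word types, wrong for multi-word ones.
print("llama".capitalize() + "Attention")      # LlamaAttention
print("qwen3_moe".capitalize() + "Attention")  # Qwen3_moeAttention (no such class)

# New logic: matches the transformers naming convention.
print(model_cls_prefix("qwen3_moe") + "Attention")  # Qwen3MoeAttention
```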

```diff
@@ -168,10 +189,18 @@ def patch_self_attn_lora(cfg: DictDefault):
     attention_cls._original_forward = self_attn_forward
     self_attn_forward, _ = detab_code(self_attn_forward)
 
-    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
+    assert any(
+        qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
+    ), "Original QKV code not found"
     assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
 
-    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
+    for qkv_orig, qkv_patched in QKV_PATCHES:
+        if qkv_orig in self_attn_forward:
+            self_attn_forward = self_attn_forward.replace(
+                qkv_orig,
+                qkv_patched,
+            )
+            break
     self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
     self_attn_forward = self_attn_forward.replace(
         "def forward(",
```

tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py

Lines changed: 15 additions & 7 deletions

```diff
@@ -9,6 +9,7 @@
 from transformers import AutoModelForCausalLM, LlamaForCausalLM
 from transformers.models.llama.configuration_llama import LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaAttention
+from transformers.models.qwen3_moe.modeling_qwen3_moe import Qwen3MoeAttention
 
 from axolotl.kernels.lora import (
     apply_lora_mlp_geglu,
@@ -66,29 +67,36 @@ def small_llama_model():
     return LlamaForCausalLM(LlamaConfig(**config))
 
 
-def test_attention_patching_integration():
+@pytest.mark.parametrize(
+    "model_name,attention_cls",
+    [
+        ("HuggingFaceTB/SmolLM2-135M", LlamaAttention),
+        ("Qwen/Qwen3-30B-A3B", Qwen3MoeAttention),
+    ],
+)
+def test_attention_patching_integration(model_name, attention_cls):
     """Test attention patching in integration context."""
-    cfg = {"base_model": "TinyLlama/TinyLlama-1.1B-Chat-v1.0"}
+    cfg = {"base_model": model_name}
 
     # Store the original implementation
-    original_forward = getattr(LlamaAttention, "forward")
+    original_forward = getattr(attention_cls, "forward")
 
     # Apply patch
     patch_self_attn_lora(cfg)
 
     # Get the new forward method
-    patched_forward = LlamaAttention.forward
+    patched_forward = attention_cls.forward
 
     # Check the forward method was replaced
     assert original_forward is not patched_forward
     assert patched_forward.__name__ == "axolotl_attn_forward"
 
     # Check original implementation was stored
-    assert hasattr(LlamaAttention, "_original_forward")
+    assert hasattr(attention_cls, "_original_forward")
 
     # Clean up
-    setattr(LlamaAttention, "forward", original_forward)
-    delattr(LlamaAttention, "_original_forward")
+    setattr(attention_cls, "forward", original_forward)
+    delattr(attention_cls, "_original_forward")
 
 
 def test_swiglu_mlp_integration(small_llama_model):
```
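
For completeness, a hypothetical one-liner to exercise just the new Qwen3 parametrization of this test via pytest's keyword filter (the generated test id contains the model name):

```python
import pytest

# Select only the parametrized case whose id mentions Qwen3.
pytest.main([
    "tests/e2e/patched/lora_kernels/test_lora_kernel_patching.py::test_attention_patching_integration",
    "-k", "Qwen3",
])
```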
