Commit ecd49ce

[Fix] Align fused moe lora_b shape with peft (vllm-project#31534)
Signed-off-by: bk-201 <[email protected]>
1 parent: e1ee11b

3 files changed: +9 -9 lines changed


docs/models/supported_models.md

Lines changed: 1 addition & 1 deletion
@@ -392,7 +392,7 @@ th {
 | `GPTBigCodeForCausalLM` | StarCoder, SantaCoder, WizardCoder | `bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, `WizardLM/WizardCoder-15B-V1.0`, etc. | ✅︎ | ✅︎ |
 | `GPTJForCausalLM` | GPT-J | `EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc. | | ✅︎ |
 | `GPTNeoXForCausalLM` | GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM | `EleutherAI/gpt-neox-20b`, `EleutherAI/pythia-12b`, `OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, `databricks/dolly-v2-12b`, `stabilityai/stablelm-tuned-alpha-7b`, etc. | | ✅︎ |
-| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | | ✅︎ |
+| `GptOssForCausalLM` | GPT-OSS | `openai/gpt-oss-120b`, `openai/gpt-oss-20b` | ✅︎ | ✅︎ |
 | `GraniteForCausalLM` | Granite 3.0, Granite 3.1, PowerLM | `ibm-granite/granite-3.0-2b-base`, `ibm-granite/granite-3.1-8b-instruct`, `ibm/PowerLM-3b`, etc. | ✅︎ | ✅︎ |
 | `GraniteMoeForCausalLM` | Granite 3.0 MoE, PowerMoE | `ibm-granite/granite-3.0-1b-a400m-base`, `ibm-granite/granite-3.0-3b-a800m-instruct`, `ibm/PowerMoE-3b`, etc. | ✅︎ | ✅︎ |
 | `GraniteMoeHybridForCausalLM` | Granite 4.0 MoE Hybrid | `ibm-granite/granite-4.0-tiny-preview`, etc. | ✅︎ | ✅︎ |
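
This docs change marks `GptOssForCausalLM` as LoRA-capable. As a quick illustration of what that enables, here is a minimal offline-inference sketch using vLLM's public LoRA API; the adapter name and path are placeholders, not artifacts of this commit:

```python
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

# Load gpt-oss with LoRA support enabled.
llm = LLM(model="openai/gpt-oss-20b", enable_lora=True)

# Attach a (hypothetical) text-to-SQL adapter at request time.
outputs = llm.generate(
    ["### Question: average working horses on farms with more than 5000 total horses"],
    SamplingParams(temperature=0, max_tokens=64),
    lora_request=LoRARequest("sql_adapter", 1, "/path/to/sql-lora-adapter"),
)
print(outputs[0].outputs[0].text)
```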

tests/lora/test_gptoss_tp.py

Lines changed: 3 additions & 3 deletions
@@ -34,9 +34,9 @@
 ###Response:<|end|><|start|>assistant<|channel|>final<|message|>""" # noqa: E501
 
 EXPECTED_LORA_OUTPUT = [
-    "SELECT AVG(Working_Horses) FROM farm WHERE Total_Horses > 5000;",
-    "SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
-    "SELECT MAX(Cows) AS Max_Cows, MIN(Cows) AS Min_Cows FROM farm;",
+    "SELECT avg(Working_Horses) FROM farm WHERE Total_Horses > 5000",
+    "SELECT max(Cows) , min(Cows) FROM farm",
+    "SELECT max(Cows) , min(Cows) FROM farm",
 ]
 
 

vllm/lora/layers/fused_moe.py

Lines changed: 5 additions & 5 deletions
@@ -679,12 +679,12 @@ def set_lora(
         # (num_experts,rank,input_size)
         w13_lora_a = w13_lora_a.reshape(num_experts, -1, w13_lora_a.shape[-1])
         w2_lora_a = w2_lora_a.reshape(num_experts, -1, w2_lora_a.shape[-1])
-        # (output_size,num_experts,rank)
-        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], num_experts, -1)
-        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], num_experts, -1)
+        # (output_size,rank,num_experts)
+        w13_lora_b = w13_lora_b.reshape(w13_lora_b.shape[0], -1, num_experts)
+        w2_lora_b = w2_lora_b.reshape(w2_lora_b.shape[0], -1, num_experts)
         # (num_experts,output_size,rank)
-        w13_lora_b = w13_lora_b.permute(1, 0, 2)
-        w2_lora_b = w2_lora_b.permute(1, 0, 2)
+        w13_lora_b = w13_lora_b.permute(2, 0, 1)
+        w2_lora_b = w2_lora_b.permute(2, 0, 1)
 
         sliced_w13_lora_a = self._slice_w13_a(w13_lora_a)
         sliced_w13_lora_b = self._slice_w13_b(w13_lora_b)
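
For context on the fix: the flattened `lora_b` tensor arrives as `(output_size, rank * num_experts)`, and this change alters how vLLM unpacks it into `(num_experts, output_size, rank)`. The sketch below (standalone PyTorch with made-up sizes, not vLLM code) checks that the new reshape-plus-permute recovers the per-expert matrices when the expert index is innermost in the flat columns, which is the layout the new code assumes peft uses:

```python
import torch

num_experts, output_size, rank = 4, 6, 2

# Reference per-expert lora_b stack: (num_experts, output_size, rank).
ref = torch.arange(num_experts * output_size * rank, dtype=torch.float32)
ref = ref.reshape(num_experts, output_size, rank)

# Flatten into the assumed peft layout: for each output row, columns are
# ordered rank-major with the expert index varying fastest.
flat = ref.permute(1, 2, 0).reshape(output_size, rank * num_experts)

# The commit's new path: reshape to (output_size, rank, num_experts),
# then permute to (num_experts, output_size, rank).
recovered = flat.reshape(output_size, -1, num_experts).permute(2, 0, 1)
assert torch.equal(recovered, ref)

# The old path assumed experts were outermost in the flat columns and
# scrambles tensors stored in this layout.
old = flat.reshape(output_size, num_experts, -1).permute(1, 0, 2)
assert not torch.equal(old, ref)
```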
