
Commit 537f6a3

[Shardformer] fix the num_heads assert for llama model and qwen model (#5704)

* fix the num_heads assert
* fix the transformers import
* [pre-commit.ci] auto fixes from pre-commit.com hooks; for more information, see https://pre-commit.ci
* fix the import

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent a3cc68c commit 537f6a3

File tree

3 files changed: +37 -30 lines changed

  colossalai/shardformer/modeling/qwen2.py
  colossalai/shardformer/policies/llama.py
  colossalai/shardformer/policies/qwen2.py

3 files changed

+37
-30
lines changed

colossalai/shardformer/modeling/qwen2.py

Lines changed: 5 additions & 5 deletions
@@ -10,16 +10,20 @@
 
 try:
     from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
         Qwen2ForCausalLM,
         Qwen2ForSequenceClassification,
         Qwen2Model,
         _prepare_4d_causal_attention_mask,
         _prepare_4d_causal_attention_mask_for_sdpa,
+        apply_rotary_pos_emb,
+        repeat_kv,
     )
 except ImportError:
     Qwen2Model = "Qwen2Model"
-    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
     Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
 
 from transformers.utils import logging
 
@@ -451,10 +455,6 @@ def qwen2_for_sequence_classification_forward(
 
 
 def get_qwen2_flash_attention_forward(shard_config: ShardConfig):
-    from transformers.models.qwen2.modeling_qwen2 import Qwen2Attention, apply_rotary_pos_emb, repeat_kv
-
-    from colossalai.shardformer.layer import ColoAttention
-
     def forward(
         self: Qwen2Attention,
         hidden_states: torch.Tensor,
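
Note: the try/except block above is the usual guard for optional transformers symbols; when the installed transformers release does not export these Qwen2 classes, the names are bound to placeholder strings so the module still imports and any policy keyed on the real classes simply never matches. A minimal, generic sketch of that pattern follows; the module and class names ("some_lib", "FusedAttention") are made up for illustration and are not part of the commit.

    # Hypothetical names used only to illustrate the import-fallback pattern;
    # this is not ColossalAI or transformers code.
    try:
        from some_lib.modeling import FusedAttention  # may be absent in old releases
    except ImportError:
        # Bind a placeholder string so this module still imports cleanly; any
        # lookup table keyed on the real class just never matches afterwards.
        FusedAttention = "FusedAttention"

    print(FusedAttention if isinstance(FusedAttention, str) else "real class available")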

colossalai/shardformer/policies/llama.py

Lines changed: 5 additions & 3 deletions
@@ -141,9 +141,11 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             assert (
                 self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
             ), f"The number of attention heads must be divisible by tensor parallel size."
-            assert (
-                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
-            ), f"The number of key_value heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads >= self.shard_config.tensor_parallel_size
+                    and self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by, and must not be less than tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,

colossalai/shardformer/policies/qwen2.py

Lines changed: 27 additions & 22 deletions
@@ -21,6 +21,26 @@
     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
+
+try:
+    from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
+        Qwen2DecoderLayer,
+        Qwen2FlashAttention2,
+        Qwen2ForCausalLM,
+        Qwen2ForSequenceClassification,
+        Qwen2Model,
+        Qwen2SdpaAttention,
+    )
+except ImportError:
+    Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
+    Qwen2SdpaAttention = "Qwen2SdpaAttention"
+    Qwen2DecoderLayer = "Qwen2DecoderLayer"
+    Qwen2Model = "Qwen2Model"
+
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]

@@ -45,21 +65,6 @@ def preprocess(self):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        try:
-            from transformers.models.qwen2.modeling_qwen2 import (
-                Qwen2Attention,
-                Qwen2DecoderLayer,
-                Qwen2FlashAttention2,
-                Qwen2Model,
-                Qwen2SdpaAttention,
-            )
-        except ImportError:
-            Qwen2Attention = "Qwen2Attention"
-            Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-            Qwen2SdpaAttention = "Qwen2SdpaAttention"
-            Qwen2DecoderLayer = "Qwen2DecoderLayer"
-            Qwen2Model = "Qwen2Model"
-
         ATTN_IMPLEMENTATION = {
             "eager": Qwen2Attention,
             "flash_attention_2": Qwen2FlashAttention2,

@@ -82,6 +87,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,

@@ -256,7 +268,6 @@ def get_held_layers(self) -> List[Module]:
 class Qwen2ModelPolicy(Qwen2Policy):
     def module_policy(self):
         policy = super().module_policy()
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
 
         if self.pipeline_stage_manager:
             # set None as default

@@ -277,10 +288,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
 
 class Qwen2ForCausalLMPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForCausalLM
-
         policy = super().module_policy()
-
         setattr(self.shard_config, "causal_lm", True)
 
         if self.shard_config.enable_tensor_parallelism:

@@ -330,10 +338,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
 
 class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForSequenceClassification
-
         policy = super().module_policy()
-
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
             new_item = {
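
Note: the asserts added in this policy exist because the lines just below them shard heads by integer division; if a head count did not divide evenly by the tensor parallel size, ranks would silently disagree on their per-rank sizes. A small standalone sketch follows; the 28/4 head counts and 3584 hidden size are assumed Qwen2-7B-style values, and the num_key_value_heads entry is included only for illustration, not read from the commit.

    # Assumed Qwen2-7B-style config values; not taken from the commit itself.
    tensor_parallel_size = 4
    hidden_size = 3584
    num_attention_heads = 28
    num_key_value_heads = 4

    # Mirror of the checks the policy performs before sharding.
    assert num_attention_heads % tensor_parallel_size == 0
    assert num_key_value_heads % tensor_parallel_size == 0

    # The per-rank attribute replacement relies on these divisions being exact.
    per_rank = {
        "self_attn.hidden_size": hidden_size // tensor_parallel_size,        # 896
        "self_attn.num_heads": num_attention_heads // tensor_parallel_size,  # 7
        "self_attn.num_key_value_heads": num_key_value_heads // tensor_parallel_size,  # 1
    }
    print(per_rank)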
