     get_qwen2_flash_attention_forward,
     get_qwen2_model_forward_for_flash_attn,
 )
+
+try:
+    from transformers.models.qwen2.modeling_qwen2 import (
+        Qwen2Attention,
+        Qwen2DecoderLayer,
+        Qwen2FlashAttention2,
+        Qwen2ForCausalLM,
+        Qwen2ForSequenceClassification,
+        Qwen2Model,
+        Qwen2SdpaAttention,
+    )
+except ImportError:
+    Qwen2ForCausalLM = "Qwen2ForCausalLM"
+    Qwen2ForSequenceClassification = "Qwen2ForSequenceClassification"
+    Qwen2Attention = "Qwen2Attention"
+    Qwen2FlashAttention2 = "Qwen2FlashAttention2"
+    Qwen2SdpaAttention = "Qwen2SdpaAttention"
+    Qwen2DecoderLayer = "Qwen2DecoderLayer"
+    Qwen2Model = "Qwen2Model"
+
 from .base_policy import ModulePolicyDescription, Policy, SubModuleReplacementDescription
 
 __all__ = ["Qwen2Policy", "Qwen2ForCausalLMPolicy", "Qwen2ForSequenceClassificationPolicy"]
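
Note: hoisting the guarded import to module level lets every policy class in this file share one try/except ImportError block; on older transformers releases that lack Qwen2, the names degrade to string placeholders so the module still imports cleanly. A minimal standalone sketch of that pattern, shown only for illustration (the print usage is not part of the change):

# Guarded-import sketch (illustrative, not part of the diff): use the real class
# when the installed transformers version ships it, otherwise fall back to the
# class name as a string so importing this module never fails.
try:
    from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
except ImportError:  # also raised as ModuleNotFoundError when transformers is absent
    Qwen2Model = "Qwen2Model"

print(Qwen2Model)  # the class object on new versions, the placeholder string otherwise
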
@@ -45,21 +65,6 @@ def preprocess(self):
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
-        try:
-            from transformers.models.qwen2.modeling_qwen2 import (
-                Qwen2Attention,
-                Qwen2DecoderLayer,
-                Qwen2FlashAttention2,
-                Qwen2Model,
-                Qwen2SdpaAttention,
-            )
-        except ImportError:
-            Qwen2Attention = "Qwen2Attention"
-            Qwen2FlashAttention2 = "Qwen2FlashAttention2"
-            Qwen2SdpaAttention = "Qwen2SdpaAttention"
-            Qwen2DecoderLayer = "Qwen2DecoderLayer"
-            Qwen2Model = "Qwen2Model"
-
         ATTN_IMPLEMENTATION = {
             "eager": Qwen2Attention,
             "flash_attention_2": Qwen2FlashAttention2,
@@ -82,6 +87,13 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             warnings.warn("Qwen2 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            if hasattr(self.model.config, "num_key_value_heads"):
+                assert (
+                    self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+                ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
@@ -256,7 +268,6 @@ def get_held_layers(self) -> List[Module]:
 class Qwen2ModelPolicy(Qwen2Policy):
     def module_policy(self):
         policy = super().module_policy()
-        from transformers.models.qwen2.modeling_qwen2 import Qwen2Model
 
         if self.pipeline_stage_manager:
             # set None as default
@@ -277,10 +288,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
 
 class Qwen2ForCausalLMPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForCausalLM
-
         policy = super().module_policy()
-
         setattr(self.shard_config, "causal_lm", True)
 
         if self.shard_config.enable_tensor_parallelism:
@@ -330,10 +338,7 @@ def get_shared_params(self) -> List[Dict[int, Tensor]]:
 
 class Qwen2ForSequenceClassificationPolicy(Qwen2Policy):
     def module_policy(self):
-        from transformers import Qwen2ForSequenceClassification
-
         policy = super().module_policy()
-
         if self.shard_config.enable_tensor_parallelism:
             # add a new item for sequence classification
             new_item = {