
Commit d3f34ee

[Shardformer] add assert for num of attention heads divisible by tp_size (#5670)
* add assert for num of attention heads divisible by tp_size

* [pre-commit.ci] auto fixes from pre-commit.com hooks

  for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 6af6d6f commit d3f34ee

File tree

13 files changed: +48 -0 lines changed
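Every hunk below adds the same guard: tensor parallelism splits the attention heads evenly across the tensor-parallel ranks, so the head count has to be an exact multiple of the tensor parallel size. A minimal sketch of the arithmetic behind the assert, using illustrative names (heads_per_rank, num_heads, tp_size) rather than the Shardformer API:

def heads_per_rank(num_heads: int, tp_size: int) -> int:
    # Each tensor-parallel rank owns an equal slice of the attention heads.
    # Without the check, e.g. 30 heads with tp_size=4 would be silently
    # truncated to 7 heads per rank by integer division, and the sharded
    # projection shapes would no longer add up to the original hidden size.
    assert num_heads % tp_size == 0, (
        f"The number of attention heads ({num_heads}) must be divisible "
        f"by tensor parallel size ({tp_size})."
    )
    return num_heads // tp_size

print(heads_per_rank(32, 4))  # 8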

colossalai/shardformer/policies/bert.py

Lines changed: 3 additions & 0 deletions
@@ -79,6 +79,9 @@ def module_policy(self):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[BertLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "attention.self.all_head_size": self.model.config.hidden_size

colossalai/shardformer/policies/blip2.py

Lines changed: 3 additions & 0 deletions
@@ -52,6 +52,9 @@ def module_policy(self):
         norm_cls = col_nn.LayerNorm
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[Blip2EncoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attn.num_heads": self.model.config.vision_config.num_attention_heads

colossalai/shardformer/policies/bloom.py

Lines changed: 3 additions & 0 deletions
@@ -61,6 +61,9 @@ def module_policy(self):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[BloomBlock] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attention.hidden_size": self.model.config.hidden_size

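The attribute each policy asserts on follows the model's config class: most policies read num_attention_heads, Bloom exposes the count as n_head, and BLIP-2 (above) and SAM (below) take it from the nested vision_config. A hypothetical helper that normalizes these variants, shown only to make the pattern explicit (not part of the Shardformer code):

def resolve_num_heads(config) -> int:
    # Prefer the nested vision config when present (BLIP-2, SAM),
    # then the common attribute name, then Bloom-style n_head.
    cfg = getattr(config, "vision_config", config)
    for name in ("num_attention_heads", "n_head"):
        if hasattr(cfg, name):
            return getattr(cfg, name)
    raise AttributeError("config does not expose an attention head count")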
colossalai/shardformer/policies/falcon.py

Lines changed: 6 additions & 0 deletions
@@ -47,6 +47,12 @@ def module_policy(self):
         embedding_cls = col_nn.PaddingEmbedding
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_kv_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             attn_attribute_replacement = {
                 "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,

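Falcon is the first policy in this commit that also asserts on the key/value head count: with multi-query or grouped-query attention the KV heads are sharded separately from the query heads, so both counts must divide evenly. The Llama and Mistral policies below apply the same pair of checks via num_key_value_heads. An illustrative helper, not part of the Shardformer API, that enumerates which tensor parallel sizes a given head layout admits:

def valid_tp_sizes(num_attention_heads: int, num_kv_heads: int) -> list:
    # A tensor parallel size works only if it divides both head counts.
    # E.g. a layout of 32 query heads and 8 KV heads (a common grouped-query
    # setup) allows tp_size in {1, 2, 4, 8}; tp_size=16 would trip the second assert.
    return [
        tp
        for tp in range(1, min(num_attention_heads, num_kv_heads) + 1)
        if num_attention_heads % tp == 0 and num_kv_heads % tp == 0
    ]

print(valid_tp_sizes(32, 8))  # [1, 2, 4, 8]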
colossalai/shardformer/policies/gpt2.py

Lines changed: 3 additions & 0 deletions
@@ -84,6 +84,9 @@ def module_policy(self):
             self.shard_config.enable_flash_attention = False
             use_flash_attention = False
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[GPT2Model] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(

colossalai/shardformer/policies/gptj.py

Lines changed: 3 additions & 0 deletions
@@ -57,6 +57,9 @@ def module_policy(self):
 
         overlap = self.shard_config.enable_sequence_overlap
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[GPTJModel] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(

colossalai/shardformer/policies/llama.py

Lines changed: 6 additions & 0 deletions
@@ -138,6 +138,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             )
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,

colossalai/shardformer/policies/mistral.py

Lines changed: 6 additions & 0 deletions
@@ -66,6 +66,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             )
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,

colossalai/shardformer/policies/opt.py

Lines changed: 3 additions & 0 deletions
@@ -76,6 +76,9 @@ def module_policy(self):
             warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[OPTDecoderLayer] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(

colossalai/shardformer/policies/sam.py

Lines changed: 3 additions & 0 deletions
@@ -31,6 +31,9 @@ def module_policy(self):
         norm_cls = col_nn.LayerNorm
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[SamVisionLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "attn.num_attention_heads": self.model.config.vision_config.num_attention_heads

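Because these asserts only fire once sharding begins, a quick pre-flight check against the Hugging Face config can reject a bad tensor parallel size before any GPUs are involved. A sketch under the assumption that the model's config exposes num_attention_heads (and optionally num_key_value_heads); the function name is illustrative, not part of ColossalAI:

from transformers import AutoConfig

def check_tp_compatibility(model_name: str, tp_size: int) -> None:
    # Load only the config, not the weights, and mirror the checks
    # added in this commit.
    config = AutoConfig.from_pretrained(model_name)
    num_heads = config.num_attention_heads
    assert num_heads % tp_size == 0, (
        f"The number of attention heads ({num_heads}) must be divisible "
        f"by tensor parallel size ({tp_size})."
    )
    num_kv_heads = getattr(config, "num_key_value_heads", None)
    if num_kv_heads is not None:
        assert num_kv_heads % tp_size == 0, (
            f"The number of key_value heads ({num_kv_heads}) must be divisible "
            f"by tensor parallel size ({tp_size})."
        )

check_tp_compatibility("gpt2", tp_size=4)  # GPT-2 small has 12 heads; 4 divides 12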
0 commit comments
