
Commit 131eece — fix tp bug

Author: Tong Li
Parent: 704866a

File tree

1 file changed: 11 additions, 1 deletion

colossalai/shardformer/policies/qwen2.py

Lines changed: 11 additions & 1 deletion
@@ -13,6 +13,7 @@
     PaddingEmbedding,
     RMSNorm,
     VocabParallelEmbedding1D,
+    VocabParallelLMHead1D,
 )

 from ..modeling.qwen2 import (
@@ -446,7 +447,16 @@ def module_policy(self):
                     suffix="lm_head",
                     target_module=LinearWithGradAccum,
                     kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
-                )
+                ),
+                SubModuleReplacementDescription(
+                    suffix="lm_head",
+                    target_module=VocabParallelLMHead1D,
+                    kwargs={
+                        "gather_output": not self.shard_config.parallel_output,
+                        "make_vocab_size_divisible_by": self.shard_config.make_vocab_size_divisible_by,
+                        "fp8_communication": self.shard_config.fp8_communication,
+                    },
+                ),
             ],
             method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
         )
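
For context: each SubModuleReplacementDescription pairs a module suffix with a target class and constructor kwargs, and shardformer later walks the model and performs the swap. The new entry routes lm_head to VocabParallelLMHead1D, with gather_output set to the negation of shard_config.parallel_output, presumably so the vocab-sharded logits are gathered only when downstream code expects full-vocabulary output. Below is a minimal sketch of the suffix-based swap mechanism itself, using only PyTorch; replace_submodule and ToyParallelLMHead are hypothetical names for illustration, not ColossalAI's actual internals.

    import torch.nn as nn

    class ToyParallelLMHead(nn.Linear):
        """Stand-in for a vocab-parallel head: the real VocabParallelLMHead1D
        shards the vocabulary dimension across tensor-parallel ranks; this toy
        only records the gather_output flag."""
        def __init__(self, in_features, out_features, gather_output=True, **kwargs):
            super().__init__(in_features, out_features, bias=False)
            self.gather_output = gather_output

    def replace_submodule(model: nn.Module, suffix: str, target_cls, **kwargs):
        # Resolve the parent module from the dotted suffix, then swap the child
        # for a new instance of target_cls built from the old layer's shape.
        parent_name, _, child_name = suffix.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        old = getattr(parent, child_name)
        setattr(parent, child_name, target_cls(old.in_features, old.out_features, **kwargs))

    model = nn.Module()
    model.lm_head = nn.Linear(16, 32, bias=False)
    replace_submodule(model, "lm_head", ToyParallelLMHead, gather_output=False)
    print(type(model.lm_head).__name__, model.lm_head.gather_output)  # ToyParallelLMHead False

In the real policy the same idea applies: the description above tells shardformer to find the module at suffix "lm_head" and rebuild it as VocabParallelLMHead1D with the listed kwargs.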
