File tree Expand file tree Collapse file tree 1 file changed +11
-1
lines changed
colossalai/shardformer/policies Expand file tree Collapse file tree 1 file changed +11
-1
lines changed Original file line number Diff line number Diff line change 13
13
PaddingEmbedding ,
14
14
RMSNorm ,
15
15
VocabParallelEmbedding1D ,
16
+ VocabParallelLMHead1D ,
16
17
)
17
18
18
19
from ..modeling .qwen2 import (
@@ -446,7 +447,16 @@ def module_policy(self):
446
447
suffix = "lm_head" ,
447
448
target_module = LinearWithGradAccum ,
448
449
kwargs = dict (fp8_communication = self .shard_config .fp8_communication , use_zbv = use_zbv ),
449
- )
450
+ ),
451
+ SubModuleReplacementDescription (
452
+ suffix = "lm_head" ,
453
+ target_module = VocabParallelLMHead1D ,
454
+ kwargs = {
455
+ "gather_output" : not self .shard_config .parallel_output ,
456
+ "make_vocab_size_divisible_by" : self .shard_config .make_vocab_size_divisible_by ,
457
+ "fp8_communication" : self .shard_config .fp8_communication ,
458
+ },
459
+ ),
450
460
],
451
461
method_replacement = {"forward" : get_lm_forward_with_dist_cross_entropy (self .shard_config )},
452
462
)
You can’t perform that action at this time.
0 commit comments