@@ -21,7 +21,9 @@
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer.layer import ColoAttention
 from colossalai.shardformer.shard import ShardConfig
+
 from ..layer import cross_entropy_1d
+
 logger = logging.get_logger(__name__)
 
 
@@ -351,7 +353,7 @@ def opt_for_causal_lm_forward(
             loss_fct = CrossEntropyLoss()
             shift_logits = shift_logits.view(-1, self.config.vocab_size)
             loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
-
+
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output
@@ -987,8 +989,8 @@ def forward(
                     process_group=shard_config.tensor_parallel_process_group,
                     vocab_size=self.lm_head.out_features,
                 )
-                #loss_fct = CrossEntropyLoss()
-                #loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
+                # loss_fct = CrossEntropyLoss()
+                # loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -1002,4 +1004,4 @@ def forward(
             attentions=outputs.attentions,
         )
 
-    return forward
+    return forward
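Taken together, these hunks route the OPT causal-LM loss through ColossalAI's tensor-parallel cross_entropy_1d (imported via `from ..layer import cross_entropy_1d`) instead of the plain CrossEntropyLoss, passing the tensor-parallel process group and the full lm_head output width. Below is a minimal, hypothetical sketch of that loss path; the helper name causal_lm_loss_1d and its parameters are illustrative (not from the commit), and it assumes cross_entropy_1d accepts the flattened logits and labels as its first two arguments, as the commented-out CrossEntropyLoss call suggests.

import torch
from torch.distributed import ProcessGroup

# Same symbol the diff imports relatively as `from ..layer import cross_entropy_1d`.
from colossalai.shardformer.layer import cross_entropy_1d


def causal_lm_loss_1d(
    logits: torch.Tensor,    # (batch, seq_len, local_vocab_shard)
    labels: torch.Tensor,    # (batch, seq_len)
    tp_group: ProcessGroup,  # shard_config.tensor_parallel_process_group in the diff
    full_vocab_size: int,    # self.lm_head.out_features in the diff
) -> torch.Tensor:
    """Hypothetical helper mirroring the loss path the diff switches to."""
    # Shift so that tokens < n predict token n, as in the original CrossEntropyLoss path.
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    # Flatten and reduce with the distributed (vocab-sharded) cross entropy.
    return cross_entropy_1d(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        process_group=tp_group,
        vocab_size=full_vocab_size,
    )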