3 files changed: +417 -11 lines changed
internlm/model/model_ops/ops (3 files changed: +417 -11 lines changed)
Columns: original file line number | diff line number | diff line change
1818 CrossEntropyApexVocabParallel ,
1919 CrossEntropyLossApex ,
2020 CrossEntropyPython ,
21+ CrossEntropyLossFlash ,
2122)
2223from internlm .utils .logger import get_logger
2324
@@ -86,17 +87,8 @@ def new_cross_entropy(
8687
8788 assert gpc .get_group (ParallelMode .TENSOR ) is not None , "The process group should not be None."
8889
89- try :
90- from flash_attn .losses .cross_entropy import (
91- CrossEntropyLoss as FlashCrossEntropyLoss ,
92- )
93-
94- flash_cross_entropy_impl = True
95- except (ModuleNotFoundError , ImportError ):
96- flash_cross_entropy_impl = False
97-
9890 assert (
99- gpc .config .model .get ("use_flash_attn" , False ) and flash_cross_entropy_impl
91+ gpc .config .model .get ("use_flash_attn" , False )
10092 ), "Only flash cross entropy support parallel_output"
10193
10294 assert (
@@ -108,7 +100,7 @@ def new_cross_entropy(
108100 which may result loss divergency in long sequence."
109101 )
110102
111- return FlashCrossEntropyLoss (
103+ return CrossEntropyLossFlash (
112104 ignore_index = ignore_index ,
113105 reduction = reduction ,
114106 label_smoothing = label_smoothing ,
Original file line number Diff line number Diff line change 22from .py_naive_loss import CrossEntropyPython
33from .py_vocab_parallel_loss import CrossEntropyApexVocabParallel
44from .sequence_parallel_loss import VocabSequenceParallelCrossEntropyLoss
5+ from .flash_loss import CrossEntropyLossFlash
56
67__all__ = [
78 "CrossEntropyLossApex" ,
89 "CrossEntropyPython" ,
910 "CrossEntropyApexVocabParallel" ,
1011 "VocabSequenceParallelCrossEntropyLoss" ,
12+ "CrossEntropyLossFlash" ,
1113]
You can’t perform that action at this time.
0 commit comments