@@ -21,23 +21,13 @@
 
 from __future__ import annotations
 
-import contextlib
-import math
-import os
-import warnings
 from functools import partial
-from typing import List, Optional, Tuple, Union
 
 import paddle
-import paddle.distributed as dist
-import paddle.distributed.fleet.meta_parallel as mpu
 import paddle.nn.functional as F
 from paddle import Tensor, nn
 from paddle.distributed import fleet
-from paddle.distributed.fleet.meta_parallel import get_rng_state_tracker
-from paddle.distributed.fleet.recompute.recompute import recompute
 from paddle.jit import to_static
-from paddle.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from paddle.utils import try_import
 
 try:
@@ -48,7 +38,6 @@
 try:
     from paddle.distributed.fleet.utils.sequence_parallel_utils import (
         GatherOp,
-        ScatterOp,
         mark_as_sequence_parallel_parameter,
     )
 except:
@@ -62,43 +51,20 @@
     flash_attention = None
 
 from config.configuration import DeepseekV2FastConfig
-from moe_gate import PretrainedMoEGate
-from moe_layer import MoEFlexTokenLayer, MoELayer
 from paddle.distributed.fleet.meta_parallel.zero_bubble_utils import WeightGradStore
 
-from paddleformers.transformers.activations import ACT2FN
-from paddleformers.transformers.conversion_utils import (
-    StateDictNameMapping,
-    init_name_mappings,
-)
-from paddleformers.transformers.deepseek_v2 import DeepseekV2RotaryEmbedding, Linear
-from paddleformers.transformers.deepseek_v2 import fp8_linear as linear_utils
 from paddleformers.transformers.deepseek_v2 import (
+    DeepseekV2RotaryEmbedding,
     yarn_find_correction_range,
     yarn_get_mscale,
     yarn_linear_ramp_mask,
 )
 from paddleformers.transformers.fp8_utils import (
-    FP8Linear,
     FP8LinearFunctionBase,
     cache_fp8_weight,
     set_parameter_color,
 )
-from paddleformers.transformers.llama import fusion_ops
-from paddleformers.transformers.llama.modeling import get_use_casual_mask
-from paddleformers.transformers.model_outputs import (
-    BaseModelOutputWithPastAndMTP,
-    CausalLMOutputWithPast,
-    SequenceClassifierOutputWithPast,
-)
-from paddleformers.transformers.model_utils import (
-    PretrainedModel,
-    dtype_guard,
-    register_base_model,
-)
-from paddleformers.transformers.utils import cast_if_needed, device_guard
-from paddleformers.utils.initializer import kaiming_uniform_
-from paddleformers.utils.log import logger
+from paddleformers.transformers.utils import device_guard
 from paddleformers.utils.tools import get_env_device
 
 try:
@@ -117,13 +83,7 @@ def swiglu(x, y=None):
 except ImportError:
     fused_partial_rope = None
 
-from paddleformers.transformers.deepseek_v2 import (
-    DeepseekV2ForCausalLM,
-    DeepseekV2ForSequenceClassification,
-    DeepseekV2Model,
-    DeepseekV2PretrainedModel,
-    DeepseekV2PretrainingCriterion,
-)
+from paddleformers.transformers.deepseek_v2 import rotate_half
 
 __all__ = [
     "DeepseekV2LMHead",
@@ -153,6 +113,13 @@ def rms_norm_fused(x_in, w, eps, use_fast_ln=False):
         return fused_ln.fused_rms_norm(x_in, w, eps)[0]
 
 
+def cast_if_needed(x, dtype):
+    """
+    cast_if_needed
+    """
+    return x.cast(dtype) if x.dtype != dtype else x
+
+
 def fusion_rms_norm(hidden_states, weight, variance_epsilon, use_fast_ln=False):
     if get_env_device() == "npu":
         return paddle.base.core.eager._run_custom_op("rms_norm_npu", hidden_states, weight, variance_epsilon)[0]
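
A minimal usage sketch (not part of the commit) of the cast_if_needed helper added in the last hunk, assuming the helper is in scope and paddle is importable; the example shapes and dtypes are illustrative only:

    import paddle

    x = paddle.ones([2, 2], dtype="float16")
    y = cast_if_needed(x, paddle.float32)  # dtypes differ, so x.cast(paddle.float32) is applied
    z = cast_if_needed(y, paddle.float32)  # dtype already matches, the tensor is returned unchanged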