 from lightllm.utils.dist_utils import get_current_rank_in_dp, get_dp_world_size
 from lightllm.models.vit.triton_kernel.gelu_vit import gelu_fwd
 from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
+from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 
 class ViTTransformerLayerInfer:
@@ -60,7 +61,9 @@ def tp_norm(self, input, weight):
 
     def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            b = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_)
+            b = rms_norm(
+                input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         else:
             b = torch.nn.functional.layer_norm(
                 input,
@@ -73,7 +76,9 @@ def _att_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
 
     def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         if layer_weight.norm_type == "rms_norm":
-            return rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_)
+            return rms_norm(
+                input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         else:
             return torch.nn.functional.layer_norm(
                 input,
@@ -84,20 +89,28 @@ def _ffn_norm(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
             )
 
     def _qk_norm(self, q, k, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
-        k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        if self.tp_world_size_ > 1:
+            q_norm = self.tp_norm(q, layer_weight.q_norm_weight_.weight)
+            k_norm = self.tp_norm(k, layer_weight.k_norm_weight_.weight)
+        else:
+            q_norm = rms_norm(
+                q, weight=layer_weight.q_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
+            k_norm = rms_norm(
+                k, weight=layer_weight.k_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True
+            )
         return q_norm, k_norm
 
     def _get_qkv(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         batch_size = input.shape[0]
         seq_len = input.shape[1]
-        qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
+        qkv = layer_weight.qkv_proj.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
         qkv = qkv.view(batch_size, seq_len, 3, -1, self.head_dim_)
         q, k, v = qkv.unbind(2)
         return q, k, v
 
     def _context_attention_kernel(self, q, k, v) -> torch.Tensor:
-        out = torch.empty_like(q)
+        out = g_cache_manager.alloc_tensor(q.shape, q.dtype, device=q.device)
         batch_size = q.shape[0]
         seq_len = q.shape[1]
         flash_attention_fwd(q, k, v, out)
@@ -107,30 +120,33 @@ def _get_o(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
         batch_size = input.shape[0]
         seq_len = input.shape[1]
         o_tensor = layer_weight.o_proj.mm(
-            input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=False
+            input.view(-1, self.tp_padding_head_num * self.head_dim_), use_custom_tensor_mananger=True
         )
         if layer_weight.use_ls:
-            o_tensor *= layer_weight.ls1
+            o_tensor.mul_(layer_weight.ls1)
         return o_tensor.reshape((batch_size, seq_len, -1))
 
     def _ffn(self, input, layer_weight: ViTTransformerLayerWeight) -> torch.Tensor:
-        fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=False)
-        # ffn1_out = torch.nn.functional.gelu(fc1)
-        ffn1_out = gelu_fwd(fc1)
+        fc1 = layer_weight.ffn_1_proj_.mm(input.view(-1, self.embed_dim_), use_custom_tensor_mananger=True)
         input_shape = input.shape
         input = None
-        ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=False)
-        if layer_weight.use_ls:
-            ffn2_out *= layer_weight.ls2
+        ffn1_out = gelu_fwd(fc1, use_custom_tensor_mananger=True)
+        ffn2_out = layer_weight.ffn_2_proj_.mm(ffn1_out, use_custom_tensor_mananger=True)
         ffn1_out = None
+        if layer_weight.use_ls:
+            ffn2_out.mul_(layer_weight.ls2)
         return ffn2_out.reshape(input_shape)
 
     def _context_attention(self, input_embding, layer_weight):
         input1 = self._att_norm(input_embding, layer_weight)
         q, k, v = self._get_qkv(input1, layer_weight)
+        input1 = None
         if layer_weight.qk_norm:
             q, k = self._qk_norm(q, k, layer_weight)
         o = self._context_attention_kernel(q, k, v)
+        q = None
+        k = None
+        v = None
         o = self._get_o(o, layer_weight)
         if self.tp_world_size_ > 1:
             dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
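
Note on the pattern above: the diff swaps ad-hoc allocations (torch.empty_like, use_custom_tensor_mananger=False) for buffers drawn from the shared cache tensor manager, scales in place with mul_, and drops references (input1 = None, q/k/v = None) so scratch memory can be recycled between layers. The sketch below illustrates that idea; the only lightllm call it relies on is g_cache_manager.alloc_tensor(shape, dtype, device=...) exactly as used in the diff, while the alloc_scratch helper and its fallback are illustrative assumptions, not part of the repo.

import torch

try:
    from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
except ImportError:  # running outside lightllm; keep the sketch importable
    g_cache_manager = None


def alloc_scratch(ref: torch.Tensor) -> torch.Tensor:
    # Prefer a recyclable buffer from the shared cache manager (same call as the diff:
    # alloc_tensor(shape, dtype, device=...)); otherwise fall back to a fresh allocation.
    # The fallback branch is an assumption for this sketch, not lightllm behavior.
    if g_cache_manager is not None:
        return g_cache_manager.alloc_tensor(ref.shape, ref.dtype, device=ref.device)
    return torch.empty_like(ref)


def scale_in_place(out: torch.Tensor, ls: torch.Tensor) -> torch.Tensor:
    # In-place scaling avoids materializing another full-size temporary, mirroring the
    # o_tensor.mul_(layer_weight.ls1) / ffn2_out.mul_(layer_weight.ls2) changes above.
    return out.mul_(ls)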