|
14 | 14 | from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k |
15 | 15 | from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd |
16 | 16 | from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v |
17 | | -from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward |
| 17 | +from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm |
18 | 18 | from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd |
19 | 19 | from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd |
20 | 20 |
|
@@ -135,14 +135,14 @@ def _att_norm( |
135 | 135 | self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight |
136 | 136 | ) -> torch.Tensor: |
137 | 137 | out = self.alloc_tensor(input.shape, input.dtype) |
138 | | - rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out) |
| 138 | +        out = rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
139 | 139 | return out |
140 | 140 |
|
141 | 141 | def _ffn_norm( |
142 | 142 | self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight |
143 | 143 | ) -> torch.Tensor: |
144 | 144 | out = self.alloc_tensor(input.shape, input.dtype) |
145 | | - rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out) |
| 145 | +        out = rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
146 | 146 | return out |
147 | 147 |
|
148 | 148 | def _get_qkv( |
|
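
For reference, below is a minimal PyTorch sketch of the RMSNorm semantics that both the removed `rmsnorm_forward` kernel and the new `rms_norm` kernel compute. It is an illustrative reference only, not the Triton implementation: `rms_norm_reference` is a hypothetical name, and the real `rms_norm` with `use_custom_tensor_mananger=True` is assumed to allocate its output through lightllm's custom tensor manager and return it, which is why the call's result is assigned to `out` above.

```python
import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
    # Root-mean-square normalization over the last (hidden) dimension,
    # followed by a learned per-channel scale, computed in fp32 for stability.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (x_normed * weight.float()).to(x.dtype)

# Hypothetical usage mirroring the edited _att_norm path:
# out = rms_norm(input, weight=layer_weight.att_norm_weight_.weight,
#                eps=self.eps_, use_custom_tensor_mananger=True)
```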