Commit feb505b

Author: sangchengmeng
Commit message: fix rms_norm
Parent: 58b7fd4

File tree: 2 files changed (+5 −5 lines)

lightllm/models/llama/layer_infer/post_layer_infer.py
Lines changed: 2 additions & 2 deletions

@@ -8,7 +8,7 @@
 from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
 from einops import rearrange
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
 from lightllm.common.basemodel import PostLayerInferTpl
 from lightllm.utils.infer_utils import mark_cost_time
 from lightllm.distributed.communication_op import all_gather
@@ -25,7 +25,7 @@ def __init__(self, network_config, mode):
         return

     def _norm(self, input, infer_state, layer_weight: LlamaPreAndPostLayerWeight) -> torch.Tensor:
-        return rmsnorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)
+        return rms_norm(input, layer_weight.final_norm_weight_, eps=self.eps_, use_custom_tensor_mananger=True)

     def _slice_get_last_input(self, input_embdings, infer_state: LlamaInferStateInfo):

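For reference, both the old rmsnorm_forward kernel and the newly imported rms_norm are expected to compute standard RMSNorm over the last dimension. A minimal plain-PyTorch sketch of that math follows; it mirrors the call sites above but is only an approximation, not the lightllm Triton kernel itself, and the helper name rms_norm_reference is hypothetical:

import torch

def rms_norm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root-mean-square of the last dimension, then scale by the learned weight.
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    return x * torch.rsqrt(variance + eps) * weight
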
lightllm/models/llama/layer_infer/transformer_layer_infer.py
Lines changed: 3 additions & 3 deletions

@@ -14,7 +14,7 @@
 from lightllm.models.llama.triton_kernel.token_attention_nopad_att1 import token_att_fwd, token_att_fwd_int8k
 from lightllm.models.llama.triton_kernel.token_attention_nopad_softmax import token_softmax_fwd
 from lightllm.models.llama.triton_kernel.token_attention_nopad_reduceV import token_att_fwd2, token_att_fwd2_int8v
-from lightllm.models.llama.triton_kernel.rmsnorm import rmsnorm_forward
+from lightllm.models.vit.triton_kernel.rms_norm_vit import rms_norm
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd

@@ -135,14 +135,14 @@ def _att_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
-        rmsnorm_forward(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, out=out)
+        rms_norm(input, weight=layer_weight.att_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
         return out

     def _ffn_norm(
         self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
         out = self.alloc_tensor(input.shape, input.dtype)
-        rmsnorm_forward(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, out=out)
+        rms_norm(input, weight=layer_weight.ffn_norm_weight_.weight, eps=self.eps_, use_custom_tensor_mananger=True)
         return out

     def _get_qkv(

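For context on where these two hooks sit: Llama layers use pre-normalization, so _att_norm and _ffn_norm run before the attention and feed-forward sub-blocks respectively. A simplified sketch of that flow, with hypothetical helper names rather than lightllm's actual transformer_layer_infer control flow:

import torch

def llama_layer_forward_sketch(hidden, attn_fn, ffn_fn, att_norm, ffn_norm):
    # Pre-norm attention block: normalize, attend, then add back the residual.
    residual = hidden
    hidden = residual + attn_fn(att_norm(hidden))
    # Pre-norm feed-forward block: normalize, apply the MLP, add back the residual.
    residual = hidden
    hidden = residual + ffn_fn(ffn_norm(hidden))
    return hidden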