Skip to content

Commit 90910f0

Browse files
author
wangzaijun
committed
fix
1 parent 7702edd commit 90910f0

File tree

1 file changed

+3
-7
lines changed

1 file changed

+3
-7
lines changed

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_template.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
3030
def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
3131
raise Exception("need to impl")
3232

33-
def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
33+
def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
3434
cache_kv = self.alloc_tensor(
3535
shape=infer_state.kv_buffer_shapedtype[0],
3636
dtype=infer_state.kv_buffer_shapedtype[1],
@@ -40,14 +40,10 @@ def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torc
4040
)
4141
return cache_kv
4242

43-
def _get_qkv(
44-
self, input, infer_state: InferStateInfo, layer_weight
45-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
43+
def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
4644
raise Exception("need to impl")
4745

48-
def _tpsp_get_qkv(
49-
self, input, infer_state: InferStateInfo, layer_weight
50-
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
46+
def _tpsp_get_qkv(self, input, infer_state: InferStateInfo, layer_weight) -> Tuple[torch.Tensor, torch.Tensor]:
5147
raise Exception("need to impl")
5248

5349
def _post_cache_kv(self, cache_kv, infer_state: InferStateInfo, layer_weight):

0 commit comments

Comments
 (0)