Skip to content

Commit 469c180

Browse files
author
wangzaijun
committed
fix
1 parent 90910f0 commit 469c180

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

lightllm/models/bloom/layer_infer/transformer_layer_infer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import torch.functional as F
44
import torch.distributed as dist
55
import numpy as np
6-
6+
from typing import Tuple
77
from lightllm.common.basemodel import TransformerLayerInferTpl
88
from lightllm.models.bloom.layer_weights.transformer_layer_weight import BloomTransformerLayerWeight
99
from lightllm.models.bloom.triton_kernel.context_flashattention_nopad import context_attention_fwd
@@ -43,7 +43,9 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTrans
4343
eps=self.eps_,
4444
)
4545

46-
def _get_qkv(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
46+
def _get_qkv(
47+
self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
48+
) -> Tuple[torch.Tensor, torch.Tensor]:
4749
q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
4850
cache_kv = self._pre_cache_kv(infer_state=infer_state, layer_weight=layer_weight)
4951
cache_kv = layer_weight.kv_proj.mm(

0 commit comments

Comments (0)