Skip to content

Commit a783b04

Browse files
committed
fix cohere
1 parent 7eef615 commit a783b04

File tree

1 file changed

+4
-5
lines changed

1 file changed

+4
-5
lines changed

lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -44,11 +44,10 @@ def _bind_rotary_emb_fwd(self):
4444
def _get_qkv(
4545
self, input, infer_state: InferStateInfo, layer_weight
4646
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
47-
q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
48-
cache_kv = torch.mm(
49-
input.view(-1, self.embed_dim_),
50-
layer_weight.kv_weight_,
51-
).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
47+
q = layer_weight.q_proj.mm(input.view(-1, self.embed_dim_))
48+
cache_kv = layer_weight.kv_proj.mm(input.view(-1, self.embed_dim_)).view(
49+
-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_
50+
)
5251

5352
if self.use_qk_norm_:
5453
q = q.view(-1, self.tp_q_head_num_, self.head_dim_)

0 commit comments

Comments
 (0)