@@ -30,7 +30,7 @@ def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.T
3030 def _ffn_norm (self , input , infer_state : InferStateInfo , layer_weight ) -> torch .Tensor :
3131 raise Exception ("need to impl" )
3232
33- def _pre_cache_kv (self , infer_state : InferStateInfo , layer_weight ) -> Tuple [ torch .Tensor , torch . Tensor ] :
33+ def _pre_cache_kv (self , infer_state : InferStateInfo , layer_weight ) -> torch .Tensor :
3434 cache_kv = self .alloc_tensor (
3535 shape = infer_state .kv_buffer_shapedtype [0 ],
3636 dtype = infer_state .kv_buffer_shapedtype [1 ],
@@ -40,14 +40,10 @@ def _pre_cache_kv(self, infer_state: InferStateInfo, layer_weight) -> Tuple[torc
4040 )
4141 return cache_kv
4242
43- def _get_qkv (
44- self , input , infer_state : InferStateInfo , layer_weight
45- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
43+ def _get_qkv (self , input , infer_state : InferStateInfo , layer_weight ) -> Tuple [torch .Tensor , torch .Tensor ]:
4644 raise Exception ("need to impl" )
4745
48- def _tpsp_get_qkv (
49- self , input , infer_state : InferStateInfo , layer_weight
50- ) -> Tuple [torch .Tensor , torch .Tensor , torch .Tensor ]:
46+ def _tpsp_get_qkv (self , input , infer_state : InferStateInfo , layer_weight ) -> Tuple [torch .Tensor , torch .Tensor ]:
5147 raise Exception ("need to impl" )
5248
5349 def _post_cache_kv (self , cache_kv , infer_state : InferStateInfo , layer_weight ):
0 commit comments