44import numpy as np
55
66from lightllm .models .starcoder .layer_weights .pre_and_post_layer_weight import PreAndPostLayerWeight
7- from lightllm .models . starcoder .infer_struct import StarcoderInferStateInfo
7+ from lightllm .common . basemodel .infer_struct import InferStateInfo
88from lightllm .utils .infer_utils import mark_cost_time
99from lightllm .common .basemodel import PreLayerInfer
1010from lightllm .models .llama .triton_kernel .embedding import embedding
@@ -23,7 +23,7 @@ def __init__(self, network_config, mode):
2323 self .vob_start_id_ = self .tp_vocab_size_ * self .tp_rank_
2424 self .vob_end_id_ = self .tp_vocab_size_ * (self .tp_rank_ + 1 )
2525
26- def context_forward (self , input_ids , infer_state : StarcoderInferStateInfo , layer_weight : PreAndPostLayerWeight ):
26+ def context_forward (self , input_ids , infer_state : InferStateInfo , layer_weight : PreAndPostLayerWeight ):
2727 total_token_num = infer_state .total_token_num
2828 input_ids = input_ids [0 :total_token_num ]
2929
@@ -43,7 +43,7 @@ def context_forward(self, input_ids, infer_state: StarcoderInferStateInfo, layer
4343
4444 return input_embdings .add_ (position_embeds )
4545
46- def token_forward (self , input_ids , infer_state : StarcoderInferStateInfo , layer_weight : PreAndPostLayerWeight ):
46+ def token_forward (self , input_ids , infer_state : InferStateInfo , layer_weight : PreAndPostLayerWeight ):
4747 # import ipdb;ipdb.set_trace()
4848 input_embdings = self .alloc_tensor (
4949 (input_ids .shape [0 ], layer_weight .wte_weight_ .shape [1 ]), dtype = layer_weight .data_type_
0 commit comments