Skip to content

Commit a3d74d1

Browse files
committed
fix
1 parent 9042b32 commit a3d74d1

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

lightllm/models/starcoder/layer_infer/pre_layer_infer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55

66
from lightllm.models.starcoder.layer_weights.pre_and_post_layer_weight import PreAndPostLayerWeight
7-
from lightllm.models.starcoder.infer_struct import StarcoderInferStateInfo
7+
from lightllm.common.basemodel.infer_struct import InferStateInfo
88
from lightllm.utils.infer_utils import mark_cost_time
99
from lightllm.common.basemodel import PreLayerInfer
1010
from lightllm.models.llama.triton_kernel.embedding import embedding
@@ -23,7 +23,7 @@ def __init__(self, network_config, mode):
2323
self.vob_start_id_ = self.tp_vocab_size_ * self.tp_rank_
2424
self.vob_end_id_ = self.tp_vocab_size_ * (self.tp_rank_ + 1)
2525

26-
def context_forward(self, input_ids, infer_state: StarcoderInferStateInfo, layer_weight: PreAndPostLayerWeight):
26+
def context_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: PreAndPostLayerWeight):
2727
total_token_num = infer_state.total_token_num
2828
input_ids = input_ids[0:total_token_num]
2929

@@ -43,7 +43,7 @@ def context_forward(self, input_ids, infer_state: StarcoderInferStateInfo, layer
4343

4444
return input_embdings.add_(position_embeds)
4545

46-
def token_forward(self, input_ids, infer_state: StarcoderInferStateInfo, layer_weight: PreAndPostLayerWeight):
46+
def token_forward(self, input_ids, infer_state: InferStateInfo, layer_weight: PreAndPostLayerWeight):
4747
# import ipdb;ipdb.set_trace()
4848
input_embdings = self.alloc_tensor(
4949
(input_ids.shape[0], layer_weight.wte_weight_.shape[1]), dtype=layer_weight.data_type_

0 commit comments

Comments
 (0)