
Commit ab79bb3

fix input/output of the other mode
1 parent 4a34160 commit ab79bb3

File tree

11 files changed (+42 / -32 lines)


lightllm/common/basemodel/basemodel.py

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ def __init__(self, kvargs):
 
         # Speculative decoding
         self.spec_algo = SpeculativeDecodeAlgorithm.from_string(kvargs.get("spec_algo", "NONE"))
+        self.spec_step = kvargs.get("spec_step", 1)
 
         self._init_datatype()
         self._init_config()

lightllm/common/basemodel/cuda_graph.py

Lines changed: 6 additions & 2 deletions
@@ -130,8 +130,12 @@ def replay(self, input_ids, infer_state, input_ids1=None, infer_state1=None):
     @torch.no_grad()
     def warmup(self, model):
         logger.info("Begin capture cudagraph, use the --disable_cudagraph to disable it.")
-        decode_len = model.spec_algo.decode_len()
-        for batch_size in range(self.max_batch_size, 0, -1):
+        if model.spec_algo is not None:
+            spec_stride = model.spec_step + 1
+        else:
+            spec_stride = 1
+
+        for batch_size in range(self.max_batch_size * spec_stride, 0, -1 * spec_stride):
             # dummy prefill
             prefill_input_len = 1
             dummy_input_ids = torch.ones((batch_size,), dtype=torch.int32, device="cuda")
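
A minimal standalone sketch of what the new loop bound does, assuming spec_step counts the extra speculative tokens per request; the concrete values below are illustrative, not lightllm defaults:

# Illustrative only: how the reworked warmup enumerates CUDA-graph batch sizes.
# max_batch_size and spec_step are made-up example values.
max_batch_size = 8
spec_step = 1                 # assumed: extra speculative tokens per request
spec_stride = spec_step + 1   # main token plus speculative tokens

batch_sizes = list(range(max_batch_size * spec_stride, 0, -1 * spec_stride))
print(batch_sizes)            # [16, 14, 12, 10, 8, 6, 4, 2]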

lightllm/server/router/model_infer/mode_backend/base_backend.py

Lines changed: 1 addition & 0 deletions
@@ -111,6 +111,7 @@ def init_model(self, kvargs):
             "quant_type": kvargs.get("quant_type", None),
             "quant_cfg": kvargs.get("quant_cfg", None),
             "spec_algo": kvargs.get("spec_algo", "NONE"),
+            "spec_step": kvargs.get("spec_step", 1),
             "run_mode": self.run_mode,
         }
         self.model, self.is_multimodal = get_model(model_cfg, model_kvargs)
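
Together with the basemodel.py change above, the new key simply threads the setting through to the model; a hypothetical illustration (the surrounding dicts are made up for the example, only the key names and defaults follow the diff):

# Hypothetical: spec_step falls back to 1 on both sides when the caller never sets it.
kvargs = {"spec_algo": "NONE"}                    # no "spec_step" provided
model_kvargs = {
    "spec_algo": kvargs.get("spec_algo", "NONE"),
    "spec_step": kvargs.get("spec_step", 1),      # default mirrors basemodel.py
}
print(model_kvargs["spec_step"])                  # 1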

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_first_token_constraint_mode.py

Lines changed: 4 additions & 4 deletions
@@ -41,8 +41,8 @@ def decode(self):
 
         # 先 decode
         if decode_reqs:
-            kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
-            logits = self.model.forward(**kwargs)
+            model_input, run_reqs = prepare_decode_inputs(decode_reqs)
+            logits = self.model.forward(model_input)
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
@@ -59,10 +59,10 @@ def decode(self):
         if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0):
             if prefill_reqs:
                 self.need_prefill_count -= 1
-                kwargs, run_reqs = prepare_prefill_inputs(
+                model_input, run_reqs = prepare_prefill_inputs(
                     prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
                 )
-                logits = self.model.forward(**kwargs)
+                logits = self.model.forward(model_input)
                 self._overlap_req_init_and_filter(
                     uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
                 )
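
The same kwargs-to-model_input change repeats in the remaining mode backends below. A rough sketch of the convention switch, using stand-in classes rather than the actual lightllm types:

# Stand-ins only: prepare_*_inputs is assumed to now return a structured object
# whose fields replace the old dict of keyword arguments.
from dataclasses import dataclass
import torch

@dataclass
class FakeModelInput:                      # hypothetical stand-in for the real input type
    input_ids: torch.Tensor
    b_seq_len: torch.Tensor

class FakeModel:
    # old call site:  logits = self.model.forward(**kwargs)
    # new call site:  logits = self.model.forward(model_input)
    def forward(self, model_input: FakeModelInput) -> torch.Tensor:
        return torch.zeros(model_input.input_ids.shape[0], 4)

model_input = FakeModelInput(
    input_ids=torch.ones(2, dtype=torch.int32),
    b_seq_len=torch.tensor([1, 1]),
)
logits = FakeModel().forward(model_input)
print(logits.shape)                        # torch.Size([2, 4])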

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_outlines_constraint_mode.py

Lines changed: 4 additions & 4 deletions
@@ -54,8 +54,8 @@ def decode(self):
 
         # 先 decode
         if decode_reqs:
-            kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
-            logits = self.model.forward(**kwargs)
+            model_input, run_reqs = prepare_decode_inputs(decode_reqs)
+            logits = self.model.forward(model_input)
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
@@ -85,10 +85,10 @@ def decode(self):
         if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0):
             if prefill_reqs:
                 self.need_prefill_count -= 1
-                kwargs, run_reqs = prepare_prefill_inputs(
+                model_input, run_reqs = prepare_prefill_inputs(
                     prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
                 )
-                logits = self.model.forward(**kwargs)
+                logits = self.model.forward(model_input)
                 self._overlap_req_init_and_filter(
                     uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
                 )

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_token_healing.py

Lines changed: 4 additions & 4 deletions
@@ -50,8 +50,8 @@ def decode(self):
 
         # 先 decode
         if decode_reqs:
-            kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
-            logits = self.model.forward(**kwargs)
+            model_input, run_reqs = prepare_decode_inputs(decode_reqs)
+            logits = self.model.forward(model_input)
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
@@ -83,10 +83,10 @@ def decode(self):
         if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0):
             if prefill_reqs:
                 self.need_prefill_count -= 1
-                kwargs, run_reqs = prepare_prefill_inputs(
+                model_input, run_reqs = prepare_prefill_inputs(
                     prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
                 )
-                logits = self.model.forward(**kwargs)
+                logits = self.model.forward(model_input)
                 self._overlap_req_init_and_filter(
                     uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
                 )

lightllm/server/router/model_infer/mode_backend/chunked_prefill/impl_for_xgrammar_mode.py

Lines changed: 4 additions & 4 deletions
@@ -48,8 +48,8 @@ def decode(self):
 
         # 先 decode
         if decode_reqs:
-            kwargs, run_reqs = prepare_decode_inputs(decode_reqs)
-            logits = self.model.forward(**kwargs)
+            model_input, run_reqs = prepare_decode_inputs(decode_reqs)
+            logits = self.model.forward(model_input)
             self._overlap_req_init_and_filter(
                 uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
             )
@@ -79,10 +79,10 @@ def decode(self):
         if len(decode_reqs) == 0 or (self.forward_step % self.max_wait_step == 0) or (self.need_prefill_count > 0):
             if prefill_reqs:
                 self.need_prefill_count -= 1
-                kwargs, run_reqs = prepare_prefill_inputs(
+                model_input, run_reqs = prepare_prefill_inputs(
                     prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
                 )
-                logits = self.model.forward(**kwargs)
+                logits = self.model.forward(model_input)
                 self._overlap_req_init_and_filter(
                     uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True
                 )

lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_return_all_prompt_logprobs.py

Lines changed: 8 additions & 6 deletions
@@ -17,12 +17,14 @@ def prefill(self, run_reqs: List[Tuple]):
         req_ids = self._init_reqs(run_reqs, init_req_obj=True)
 
         req_objs = self._trans_req_ids_to_req_objs(req_ids)
-        kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal)
-
-        prompt_all_logits = self.model.forward(**kwargs)
-        input_ids = kwargs["input_ids"]
-        b_ready_cache_len = kwargs["b_ready_cache_len"]
-        b_seq_len = kwargs["b_seq_len"]
+        model_input, run_reqs = prepare_prefill_inputs(
+            req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
+        )
+
+        prompt_all_logits = self.model.forward(model_input)
+        input_ids = model_input.input_ids
+        b_ready_cache_len = model_input.b_ready_cache_len
+        b_seq_len = model_input.b_seq_len
         last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1
         logits = prompt_all_logits[last_index, :]

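A toy check of the last-token selection kept as context in the hunk above: the cumulative sum of b_seq_len minus one gives, for each request, the index of its final prompt position in the packed logits tensor (shapes below are illustrative):

# Toy example of: last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1
import torch

vocab_size = 5                                    # made-up vocabulary size
b_seq_len = torch.tensor([3, 2, 4])               # made-up prompt lengths per request
total_tokens = int(b_seq_len.sum())
prompt_all_logits = torch.randn(total_tokens, vocab_size)

last_index = torch.cumsum(b_seq_len, dim=0, dtype=torch.long) - 1
logits = prompt_all_logits[last_index, :]         # one row per request (its last prompt token)
print(last_index.tolist(), logits.shape)          # [2, 4, 8] torch.Size([3, 5])
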
lightllm/server/router/model_infer/mode_backend/continues_batch/impl_for_reward_model.py

Lines changed: 4 additions & 2 deletions
@@ -14,9 +14,11 @@ def prefill(self, reqs: List[Tuple]):
         req_ids = self._init_reqs(reqs, init_req_obj=True)
 
         req_objs = self._trans_req_ids_to_req_objs(req_ids)
-        kwargs, run_reqs = prepare_prefill_inputs(req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal)
+        model_input, run_reqs = prepare_prefill_inputs(
+            req_objs, is_chuncked_mode=False, is_multimodal=self.is_multimodal
+        )
 
-        scores: torch.Tensor = self.model.forward(**kwargs)
+        scores: torch.Tensor = self.model.forward(model_input)
         scores = scores.unsqueeze(1).detach().cpu().float().numpy()
 
         next_token_id = 1

lightllm/server/router/model_infer/mode_backend/continues_batch/pd_mode/prefill_node_impl/prefill_impl.py

Lines changed: 2 additions & 2 deletions
@@ -72,11 +72,11 @@ def decode(self):
             self._filter_reqs(ok_finished_reqs)
 
         if prefill_reqs:
-            kwargs, run_reqs = prepare_prefill_inputs(
+            model_input, run_reqs = prepare_prefill_inputs(
                 prefill_reqs, is_chuncked_mode=True, is_multimodal=self.is_multimodal
             )
 
-            logits = self.model.forward(**kwargs)
+            logits = self.model.forward(model_input)
             next_token_ids, next_token_probs = sample(logits, run_reqs, self.eos_id)
             next_token_ids = next_token_ids.detach().cpu().numpy()
             next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy()
