|
| 1 | +import torch |
| 2 | +import numpy as np |
| 3 | +from typing import List, Tuple |
| 4 | +from lightllm.server.router.model_infer.mode_backend.base_backend import ModeBackend |
| 5 | +from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end |
| 6 | +from lightllm.utils.log_utils import init_logger |
| 7 | +from lightllm.server.router.model_infer.infer_batch import g_infer_context |
| 8 | +from lightllm.server.router.model_infer.mode_backend.generic_pre_process import ( |
| 9 | + prepare_prefill_inputs, |
| 10 | +) |
| 11 | +from lightllm.server.router.model_infer.mode_backend.mtp_pre_process import ( |
| 12 | + prepare_mtp_prefill_inputs, |
| 13 | + prepare_draft_main_model_decode_inputs, |
| 14 | +) |
| 15 | +from lightllm.server.router.model_infer.mode_backend.generic_post_process import sample |
| 16 | +import os |
| 17 | +from lightllm.common.basemodel.infer_lock import g_infer_state_lock |
| 18 | +from lightllm.server.router.model_infer.infer_batch import InferReq |
| 19 | +from lightllm.server.router.model_infer.mode_backend.continues_batch.impl_mtp import ContinuesBatchWithMTPBackend |
| 20 | +import copy |
| 21 | +from lightllm.utils.dist_utils import device0_print |
| 22 | + |
| 23 | + |
| 24 | +logger = init_logger(__name__) |
| 25 | + |
| 26 | + |
| 27 | +class ChunkedPrefillWithMTPBackend(ContinuesBatchWithMTPBackend): |
| 28 | + def __init__(self) -> None: |
| 29 | + super().__init__() |
| 30 | + |
| 31 | + def decode(self): |
| 32 | + uninit_reqs, aborted_reqs, ok_finished_reqs, prefill_reqs, decode_reqs = self._get_classed_reqs( |
| 33 | + g_infer_context.infer_req_ids |
| 34 | + ) |
| 35 | + |
| 36 | + if aborted_reqs: |
| 37 | + g_infer_context.filter_reqs(aborted_reqs) |
| 38 | + |
| 39 | + if prefill_reqs: |
| 40 | + model_input, run_reqs = prepare_prefill_inputs( |
| 41 | + prefill_reqs, is_chuncked_mode=False, is_multimodal=self.is_multimodal |
| 42 | + ) |
| 43 | + model_output = self.model.forward(model_input) |
| 44 | + |
| 45 | + self._overlap_req_init_and_filter( |
| 46 | + uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True |
| 47 | + ) |
| 48 | + |
| 49 | + next_token_ids, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) |
| 50 | + next_token_ids = next_token_ids.detach().cpu().numpy() |
| 51 | + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() |
| 52 | + self._post_handle( |
| 53 | + run_reqs, next_token_ids, next_token_logprobs, is_chuncked_mode=False, do_filter_finished_reqs=False |
| 54 | + ) |
| 55 | + # spec prefill: MTP |
| 56 | + last_input_ids_cpu = None |
| 57 | + draft_model_input = model_input |
| 58 | + last_hidden_states = model_output.hidden_states |
| 59 | + for draft_model_idx in range(self.spec_step): |
| 60 | + device0_print(f"main {draft_model_input}") |
| 61 | + draft_model_input, last_input_ids_cpu = prepare_mtp_prefill_inputs( |
| 62 | + prefill_reqs, model_input, last_hidden_states, next_token_ids, last_input_ids_cpu |
| 63 | + ) |
| 64 | + device0_print(f"draft_model_input {draft_model_input}") |
| 65 | + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) |
| 66 | + draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id) |
| 67 | + draft_next_token_ids = draft_next_token_ids.detach().cpu().numpy() |
| 68 | + |
| 69 | + last_hidden_states = draft_model_output.hidden_states |
| 70 | + next_token_ids = draft_next_token_ids |
| 71 | + self._save_draft_token_ids(draft_next_token_ids, run_reqs, draft_model_idx) |
| 72 | + |
| 73 | + if decode_reqs: |
| 74 | + model_input, run_reqs, mem_indexes_cpu = prepare_draft_main_model_decode_inputs( |
| 75 | + decode_reqs, self.draft_token_id_map |
| 76 | + ) |
| 77 | + model_output = self.model.forward(model_input) |
| 78 | + assert model_output.logits.shape[0] % self.spec_stride == 0 |
| 79 | + |
| 80 | + self._overlap_req_init_and_filter( |
| 81 | + uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True |
| 82 | + ) |
| 83 | + |
| 84 | + next_token_ids_cuda, next_token_probs = sample(model_output.logits, run_reqs, self.eos_id) |
| 85 | + next_token_ids = next_token_ids_cuda.detach().cpu().numpy() |
| 86 | + next_token_logprobs = torch.log(next_token_probs).detach().cpu().numpy() |
| 87 | + |
| 88 | + # verify |
| 89 | + accepted_reqs, accepted_index, need_free_mem_indexes = self.verify( |
| 90 | + next_token_ids, run_reqs, mem_indexes_cpu |
| 91 | + ) |
| 92 | + self._post_handle( |
| 93 | + accepted_reqs, |
| 94 | + next_token_ids[accepted_index], |
| 95 | + next_token_logprobs[accepted_index], |
| 96 | + is_chuncked_mode=False, |
| 97 | + do_filter_finished_reqs=False, |
| 98 | + ) |
| 99 | + # share some inference info with the main model |
| 100 | + draft_model_input = model_input |
| 101 | + draft_model_input.input_ids = next_token_ids_cuda |
| 102 | + draft_model_input.hidden_states = model_output.hidden_states |
| 103 | + # process the draft model output |
| 104 | + for draft_model_idx in range(self.spec_step): |
| 105 | + # spec decode: MTP |
| 106 | + draft_model_output = self.draft_models[draft_model_idx].forward(draft_model_input) |
| 107 | + draft_next_token_ids, _ = sample(draft_model_output.logits, run_reqs, self.eos_id) |
| 108 | + # prepare inputs for the next draft model |
| 109 | + draft_model_input.input_ids = draft_next_token_ids |
| 110 | + draft_model_input.hidden_states = draft_model_output.hidden_states |
| 111 | + draft_next_token_ids_numpy = draft_next_token_ids.detach().cpu().numpy() |
| 112 | + self._save_draft_token_ids(draft_next_token_ids_numpy, run_reqs, draft_model_idx) |
| 113 | + |
| 114 | + if need_free_mem_indexes: |
| 115 | + g_infer_state_lock.acquire() |
| 116 | + g_infer_context.req_manager.mem_manager.free(need_free_mem_indexes) |
| 117 | + g_infer_state_lock.release() |
| 118 | + |
| 119 | + self._overlap_req_init_and_filter(uninit_reqs=uninit_reqs, ok_finished_reqs=ok_finished_reqs, clear_list=True) |
| 120 | + return |
| 121 | + |
| 122 | + def verify(self, next_token_ids, run_reqs, draft_mem_indexes): |
| 123 | + accepted_reqs = [] |
| 124 | + accepted_index = [] |
| 125 | + need_free_mem_indexes = [] |
| 126 | + assert next_token_ids.shape[0] % self.spec_stride == 0 |
| 127 | + |
| 128 | + for i, req in enumerate(run_reqs): |
| 129 | + # main model output |
| 130 | + if i % self.spec_stride == 0: |
| 131 | + accepted_reqs.append(req) |
| 132 | + accepted_index.append(i) |
| 133 | + continue |
| 134 | + draft_model_idx = i % self.spec_stride - 1 |
| 135 | + if ( |
| 136 | + self.draft_token_id_map[req.req_idx][draft_model_idx] == next_token_ids[i - 1] |
| 137 | + and req.cur_accepted_len == draft_model_idx |
| 138 | + ): |
| 139 | + accepted_reqs.append(req) |
| 140 | + accepted_index.append(i) |
| 141 | + req.cur_accepted_len += 1 |
| 142 | + device0_print(f"req {req.req_idx} accepted, cur_accepted_len {req.cur_accepted_len}") |
| 143 | + else: |
| 144 | + need_free_mem_indexes.append(draft_mem_indexes[i]) |
| 145 | + return accepted_reqs, accepted_index, need_free_mem_indexes |
| 146 | + |
| 147 | + def _save_draft_token_ids(self, draft_next_token_ids, run_reqs, draft_model_idx): |
| 148 | + batch_size = len(run_reqs) // self.spec_stride |
| 149 | + for i in range(batch_size): |
| 150 | + req = run_reqs[self.spec_stride * i] |
| 151 | + self.draft_token_id_map[req.req_idx][draft_model_idx] = draft_next_token_ids[i + req.cur_accepted_len] |
| 152 | + # reset the cur_accepted_len |
| 153 | + if draft_model_idx == self.spec_step - 1: |
| 154 | + req.cur_accepted_len = 0 |
0 commit comments