diff --git a/llgtrt/src/async_exec.rs b/llgtrt/src/async_exec.rs
index 5fa87c7..9b22243 100644
--- a/llgtrt/src/async_exec.rs
+++ b/llgtrt/src/async_exec.rs
@@ -12,8 +12,7 @@ use std::{
 use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
 use toktrie::{SimpleVob, TokEnv};
 use trtllm_rs::{
-    ClientReqId, Executor, ExecutorInit, MaskAllocator, ReqId, RequestInit, ResponseChunk,
-    TlcLogitsEntry,
+    ClientReqId, Executor, ExecutorInit, MaskAllocator, ReqId, RequestInit, Responder,
+    ResponseChunk, TlcLogitsEntry,
 };
 
 use crate::{
@@ -78,10 +77,12 @@ impl Display for ReqData {
 
 pub struct AsyncExecutor {
     executor: Executor,
+    draft_executor: Option<Executor>,
     n_vocab: usize,
     max_batch_size: usize,
     req_to_client: HashMap<ReqId, ClientReqId>,
     req_data: HashMap<ClientReqId, ReqData>,
+    n_draft_tokens: u32, // TODO: probably belongs elsewhere, especially if it becomes configurable per call
 }
 
 static mut GLOBAL_ALLOCATOR: *const MaskAllocator = ptr::null();
@@ -364,16 +365,35 @@ impl AsyncExecutor {
         self.executor.cancel_request(req_id)
     }
 
+    pub fn has_draft_model(&self) -> bool {
+        self.draft_executor.is_some()
+    }
+
+    pub fn n_draft_tokens(&self) -> u32 {
+        self.n_draft_tokens
+    }
+
     pub fn new(
         cli_config: &CliConfig,
         config: &LlgTrtConfig,
         mut executor_init: ExecutorInit,
+        draft_executor_init: Option<ExecutorInit>,
     ) -> Result<(Self, TokEnv, ChatBuilder)> {
         executor_init.logits_callback = Some(logits_processor);
-        let max_batch_size = executor_init.trt_params.max_batch_size as usize;
+        let mut max_batch_size = executor_init.trt_params.max_batch_size as usize;
         log::info!("new executor: max_batch_size={max_batch_size}");
         let (executor, mut responder) = Executor::new(executor_init)?;
 
+        let (draft_executor, draft_responder) = if let Some(mut draft_init) = draft_executor_init {
+            draft_init.logits_callback = Some(logits_processor);
+            max_batch_size = draft_init.trt_params.max_batch_size as usize;
+            log::info!("new draft executor: max_batch_size={max_batch_size}");
+            let (executor, responder) = Executor::new(draft_init)?;
+            (Some(executor), Some(responder))
+        } else {
+            (None, None)
+        };
+
         // on non-0 ranks, this will just wait until the rank 0 exits and then exit the process
         executor.check_mpi();
 
@@ -384,50 +404,64 @@ impl AsyncExecutor {
 
         let res = Self {
             executor,
+            draft_executor,
             req_data: HashMap::new(),
             req_to_client: HashMap::new(),
             n_vocab,
             max_batch_size,
+            n_draft_tokens: 5, // TODO: hard-coded placeholder; take this from config
         };
 
-        rayon::spawn(move || loop {
-            let resps = responder
-                .await_responses(std::time::Duration::from_millis(1))
-                .unwrap();
-            if resps.len() == 0 {
-                continue;
-            }
-
-            let mut exec = AsyncExecutor::lock();
-            for resp in resps {
-                let req_id = resp.req_id;
-                if let Some(client_req_id) = exec.req_to_client.get(&req_id) {
-                    let client_req_id = *client_req_id;
-                    let rd = exec.req_data.get_mut(&client_req_id).unwrap();
-                    let is_req_final = resp.is_req_final;
-                    let idx = resp.sequence_idx as usize;
-
-                    let mut r = StepResults {
-                        response: resp,
-                        logs: std::mem::take(&mut rd.logs),
-                        final_llg: None,
-                    };
-                    if rd.llgs.len() > 0 && r.response.finish_reason.is_some() {
-                        r.final_llg = std::mem::take(&mut rd.llgs[idx]);
-                    }
-                    if rd.tx.send(r).is_err() {
-                        log::warn!("connection dropped; req={}", req_id);
-                        let _ = exec.cancel_request(req_id);
-                    } else if is_req_final {
-                        // no more data coming from here
-                        exec.drop_request_data(req_id);
-                    }
-                } else {
-                    log::warn!("Response for unknown request: {:?}", req_id);
-                    let _ = exec.executor.cancel_request(req_id);
-                }
-            }
-        });
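+        // Both the target executor and the optional draft executor get their own
+        // response-draining worker below, sharing the same loop body. Note that
+        // cancel_request for unknown req_ids currently always goes to the target
+        // executor; if that turns out to matter for draft requests it will need to
+        // be routed per executor.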
+        let spawn_responder_loop = |mut responder: Responder| {
+            rayon::spawn(move || loop {
+                let resps = responder
+                    .await_responses(std::time::Duration::from_millis(1))
+                    .unwrap();
+                if resps.len() == 0 {
+                    continue;
+                }
+
+                let mut exec = AsyncExecutor::lock();
+                for resp in resps {
+                    let req_id = resp.req_id;
+                    if let Some(client_req_id) = exec.req_to_client.get(&req_id) {
+                        let client_req_id = *client_req_id;
+                        let rd = exec.req_data.get_mut(&client_req_id).unwrap();
+                        let is_req_final = resp.is_req_final;
+                        let idx = resp.sequence_idx as usize;
+
+                        let mut r = StepResults {
+                            response: resp,
+                            logs: std::mem::take(&mut rd.logs),
+                            final_llg: None,
+                        };
+                        if rd.llgs.len() > 0 && r.response.finish_reason.is_some() {
+                            r.final_llg = std::mem::take(&mut rd.llgs[idx]);
+                        }
+                        if rd.tx.send(r).is_err() {
+                            log::warn!("connection dropped; req={}", req_id);
+                            let _ = exec.cancel_request(req_id);
+                        } else if is_req_final {
+                            // no more data coming from here
+                            exec.drop_request_data(req_id);
+                        }
+                    } else {
+                        log::warn!("Response for unknown request: {:?}", req_id);
+                        let _ = exec.executor.cancel_request(req_id);
+                    }
+                }
+            });
+        };
+
+        spawn_responder_loop(responder);
+        if let Some(draft_responder) = draft_responder {
+            spawn_responder_loop(draft_responder);
+        }
+
         Ok((res, tok_env, chat_builder))
     }
 
@@ -435,11 +469,37 @@ impl AsyncExecutor {
         self.executor.can_enqueue_request()
     }
 
+    pub fn add_draft_request(
+        &mut self,
+        init: &RequestInit,
+        prompt_params: Option<Arc<PromptParams>>,
+        llgs: Vec<Option<Constraint>>,
+    ) -> Result<(ReqId, UnboundedReceiver<StepResults>)> {
+        debug_assert!(self.draft_executor.is_some());
+        // TODO: adjust init.params (e.g. max_new_tokens) for draft requests if needed
+        self.add_request_to_executor(true, init, prompt_params, llgs)
+    }
+
     pub fn add_request(
         &mut self,
         init: &RequestInit,
         prompt_params: Option<Arc<PromptParams>>,
         llgs: Vec<Option<Constraint>>,
+    ) -> Result<(ReqId, UnboundedReceiver<StepResults>)> {
+        self.add_request_to_executor(false, init, prompt_params, llgs)
+    }
+
+    fn add_request_to_executor(
+        &mut self,
+        use_draft: bool,
+        init: &RequestInit,
+        prompt_params: Option<Arc<PromptParams>>,
+        llgs: Vec<Option<Constraint>>,
     ) -> Result<(ReqId, UnboundedReceiver<StepResults>)> {
         let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
 
@@ -451,7 +511,6 @@ impl AsyncExecutor {
 
         let pp = prompt_params.as_ref().map(|p| &p.tlc_prompt_params);
 
         // we're locked here, so it's safe to insert only after enqueuing
-        let req_id = self.executor.enqueue_request(init, pp)?;
+        let executor = if use_draft {
+            self.draft_executor
+                .as_mut()
+                .expect("draft executor requested but not configured")
+        } else {
+            &mut self.executor
+        };
+        let req_id = executor.enqueue_request(init, pp)?;
         self.req_data.insert(
             client_req_id,
diff --git a/llgtrt/src/config.rs b/llgtrt/src/config.rs
index 6781448..2e466de 100644
--- a/llgtrt/src/config.rs
+++ b/llgtrt/src/config.rs
@@ -82,6 +82,7 @@ impl Default for TrtLlmRuntimeConfig {
             cross_kv_cache_fraction: None,
             secondary_offload_min_priority: None,
             event_buffer_max_size: None,
+            n_draft_tokens: None,
         }
     }
 }
@@ -158,6 +159,14 @@ pub struct CliConfig {
     #[arg(long, short = 'T')]
     pub tokenizer: Option<String>,
 
+    /// Path to a compiled TensorRT-LLM draft engine
+    #[arg(long)]
+    pub draft_engine: Option<String>,
+
+    /// Path to folder with HF tokenizer.json and tokenizer_config.json files for the draft engine; defaults to --draft-engine
+    #[arg(long)]
+    pub draft_tokenizer: Option<String>,
+
     /// Debug output
     #[arg(long, short = 'd')]
     pub debug: bool,
diff --git a/llgtrt/src/routes/completions.rs b/llgtrt/src/routes/completions.rs
index d4d1829..703f90e 100644
--- a/llgtrt/src/routes/completions.rs
+++ b/llgtrt/src/routes/completions.rs
@@ -19,7 +20,7 @@ use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 use tokio::sync::mpsc::UnboundedReceiver;
 use toktrie::TokEnv;
-use trtllm_rs::{ClientReqId, LoraParams, ReqId, RequestInit, RequestParams, Tensor, TlcDataType};
+use trtllm_rs::{
+    ClientReqId, DraftParams, LoraParams, ReqId, RequestInit, RequestParams, Tensor, TlcDataType,
+};
 use uuid::Uuid;
 
 use crate::async_exec::{map_finish_reason, AsyncExecutor, StepResults};
@@ -341,6 +342,7 @@ fn build_request_init(
         client_req_id,
         is_run,
         lora_params,
+        draft_params: None, // TODO fill out for speculative decoding
     };
     Ok(request_init)
 }
@@ -484,7 +486,7 @@ async fn mk_req_info(
         LoadLoraWeightsOption::Always => true,
         LoadLoraWeightsOption::Never => false,
     };
-    let req_init = build_request_init(
+    let mut req_init = build_request_init(
         tokens.clone(),
         req_params.clone(),
         client_req_id,
         is_run,
         &params,
         app_state,
         load_lora_weights,
     )?;
     let prompt_tokens = req_init.tokens.len();
-
-    let (req_id, recv) = AsyncExecutor::lock().add_request(
-        &req_init,
-        req_input.prompt_params.clone(),
-        llg.clone(),
-    )?;
-
-    let info = build_req_info(
-        req_id,
-        n_forks,
-        client_req_id,
-        cmpl_id.clone(),
-        &prompt,
-        prompt_tokens,
-        params,
-        &app_state.tok_env,
-        is_chat,
-        is_run,
-        recv,
-    )?;
-
-    match completions_stream_or_not(params.stream, info).await {
-        Ok(r) => Ok(r),
-        Err(err) => {
-            if is_lora_cache_miss_error(&err) {
-                log::warn!("LoRA cache miss: {}", err);
-
-                // If necessary, retry with LoRA weights set
-                if params.load_lora_weights == LoadLoraWeightsOption::Auto {
-                    log::info!("Retrying with LoRA weights set");
-                    let req_init = build_request_init(
-                        tokens.clone(),
-                        req_params.clone(),
-                        client_req_id,
-                        is_run,
-                        &params,
-                        app_state,
-                        true,
-                    )?;
-                    let prompt_tokens = req_init.tokens.len();
-
-                    let (req_id, recv) = AsyncExecutor::lock().add_request(
-                        &req_init,
-                        req_input.prompt_params.clone(),
-                        llg.clone(),
-                    )?;
-
-                    let info = build_req_info(
-                        req_id,
-                        n_forks,
-                        client_req_id,
-                        cmpl_id.clone(),
-                        &prompt,
-                        prompt_tokens,
-                        params,
-                        &app_state.tok_env,
-                        is_chat,
-                        is_run,
-                        recv,
-                    )?;
-
-                    completions_stream_or_not(params.stream, info).await
-                } else {
-                    Ok(AppError::from(anyhow!(
-                        "LoRA model {:?} was not in cache and load_lora_weights is set to never: {}", params.lora_model, err
-                    )).into_response())
-                }
-            } else {
-                Ok(err.into_response())
-            }
-        }
-    }
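+
+    // Rough intended flow when a draft engine is configured (assumes the executors
+    // return generation logits so they can eventually be handed back to TRT-LLM as
+    // external draft tokens):
+    //   1. ask the draft executor for up to n_draft_tokens tokens
+    //   2. attach those tokens to the request as draft_params
+    //   3. run the target executor, which verifies/accepts the draft tokens
+    //   4. append the accepted output to the prompt and repeat until
+    //      max_new_tokens is reached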
+    // TODO: draft-executor path (still work in progress)
+    if AsyncExecutor::lock().has_draft_model() {
+        let start_len = req_init.tokens.len();
+        let n_gen_tokens = req_init.params.max_new_tokens;
+        // TODO: allow overriding the number of draft tokens per request
+        let n_draft_tokens = AsyncExecutor::lock().n_draft_tokens();
+        let mut last_req_info = None;
+        while req_init.tokens.len() < start_len + n_gen_tokens as usize {
+            req_init.params.max_new_tokens = n_draft_tokens; // TODO: set a minimum?
+            let (req_id, recv) = AsyncExecutor::lock().add_draft_request(
+                &req_init,
+                req_input.prompt_params.clone(), // TODO: needed here?
+                llg.clone(),
+            )?;
+
+            let mut req_info = build_req_info(
+                req_id,
+                n_forks,
+                client_req_id,
+                cmpl_id.clone(),
+                &prompt,
+                prompt_tokens,
+                params,
+                &app_state.tok_env,
+                is_chat,
+                is_run,
+                recv,
+            )?;
+
+            // TODO: proper error handling
+            // TODO: gather_response_chunks shouldn't need to take and return the whole ReqInfo
+            req_info.usage.completion_tokens = n_draft_tokens as usize;
+            let (req_info, log_probs) = gather_response_chunks(req_info).await?;
+            let draft_tokens: Vec<_> = log_probs.iter().map(|top| top.chosen.token).collect();
+            req_init.draft_params = Some(DraftParams {
+                num_tokens: draft_tokens.len() as u32,
+                draft_tokens,
+                // TODO: pass the draft generation logits here once they are plumbed through
+                logits_tensor: None,
+            });
+
+            req_init.params.max_new_tokens = n_draft_tokens + 1; // TODO: double check this
+            let (req_id, recv) = AsyncExecutor::lock().add_request(
+                &req_init,
+                req_input.prompt_params.clone(),
+                llg.clone(),
+            )?;
+
+            let req_info = build_req_info(
+                req_id,
+                n_forks,
+                client_req_id,
+                cmpl_id.clone(),
+                &prompt,
+                prompt_tokens,
+                params,
+                &app_state.tok_env,
+                is_chat,
+                is_run,
+                recv,
+            )?;
+
+            let (req_info, log_probs) = gather_response_chunks(req_info).await?;
+            let target_tokens: Vec<_> = log_probs.iter().map(|top| top.chosen.token).collect();
+            // TODO: double check whether gather_response_chunks also returns prompt tokens
+            req_init = build_request_init(
+                [req_init.tokens, target_tokens].concat(),
+                req_params.clone(),
+                client_req_id,
+                is_run,
+                &params,
+                app_state,
+                load_lora_weights,
+            )?;
+            last_req_info = Some(req_info);
+        }
+
+        // TODO: the final req_info has already been fully gathered above, so the normal
+        // completions path shouldn't need to re-gather it; streaming is also not
+        // supported on the draft path yet.
+        let req_info = last_req_info
+            .ok_or_else(|| anyhow!("speculative decoding requested zero new tokens"))?;
+        completions_stream_or_not(false, req_info).await
+    } else {
+        let (req_id, recv) = AsyncExecutor::lock().add_request(
+            &req_init,
+            req_input.prompt_params.clone(),
+            llg.clone(),
+        )?;
+
+        let info = build_req_info(
+            req_id,
+            n_forks,
+            client_req_id,
+            cmpl_id.clone(),
+            &prompt,
+            prompt_tokens,
+            params,
+            &app_state.tok_env,
+            is_chat,
+            is_run,
+            recv,
+        )?;
+
+        match completions_stream_or_not(params.stream, info).await {
+            Ok(r) => Ok(r),
+            Err(err) => {
+                if is_lora_cache_miss_error(&err) {
+                    log::warn!("LoRA cache miss: {}", err);
+
+                    // If necessary, retry with LoRA weights set
+                    if params.load_lora_weights == LoadLoraWeightsOption::Auto {
+                        log::info!("Retrying with LoRA weights set");
+                        let req_init = build_request_init(
+                            tokens.clone(),
+                            req_params.clone(),
+                            client_req_id,
+                            is_run,
+                            &params,
+                            app_state,
+                            true,
+                        )?;
+                        let prompt_tokens = req_init.tokens.len();
+
+                        let (req_id, recv) = AsyncExecutor::lock().add_request(
+                            &req_init,
+                            req_input.prompt_params.clone(),
+                            llg.clone(),
+                        )?;
+
+                        let info = build_req_info(
+                            req_id,
+                            n_forks,
+                            client_req_id,
+                            cmpl_id.clone(),
+                            &prompt,
+                            prompt_tokens,
+                            params,
+                            &app_state.tok_env,
+                            is_chat,
+                            is_run,
+                            recv,
+                        )?;
+
+                        completions_stream_or_not(params.stream, info).await
+                    } else {
+                        Ok(AppError::from(anyhow!(
+                            "LoRA model {:?} was not in cache and load_lora_weights is set to never: {}", params.lora_model, err
+                        )).into_response())
+                    }
+                } else {
+                    Ok(err.into_response())
+                }
+            }
+        }
+    }
 }
@@ -983,8 +1064,7 @@ async fn completions_stream(
     Ok(Sse::new(response_stream))
 }
 
-async fn completions(mut client: ReqInfo) -> Result<Json<Completion>, AppError> {
-    let mut token = client.cancel_token();
+async fn gather_response_chunks(mut client: ReqInfo) -> Result<(ReqInfo, Vec), AppError> {
     let mut logprobs = vec![];
     while let Some(mut result) = client.recv.recv().await {
         log::trace!("infer response: {:?}", result.response);
@@ -1009,6 +1089,36 @@ async fn completions(mut client: ReqInfo) -> Result<Json<Completion>, AppError> {
         }
     }
 
+    Ok((client, logprobs))
+}
+
+async fn completions(mut client: ReqInfo) -> Result<Json<Completion>, AppError> {
+    let mut token = client.cancel_token();
+    // response gathering now lives in gather_response_chunks so the draft path can reuse it
+    let (mut client, logprobs) = gather_response_chunks(client).await?;
+
     let created = SystemTime::now().duration_since(UNIX_EPOCH)?.as_secs();
     let id = format!("cmpl-{}", Uuid::new_v4());
diff --git a/llgtrt/src/startup.rs b/llgtrt/src/startup.rs
index 8559184..7e7c970 100644
--- a/llgtrt/src/startup.rs
+++ b/llgtrt/src/startup.rs
@@ -46,6 +46,13 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {
         trt_params: Default::default(),
     };
 
+    // TODO: the draft executor reuses default TRT params for now
+    let draft_exec_config = cli_config.draft_engine.clone().map(|engine_path| ExecutorInit {
+        engine_path,
+        logits_callback: None,
+        trt_params: Default::default(),
+    });
+
     let defl_config_path = format!("{}/llgtrt.json5", cli_config.engine);
     if cli_config.config.is_empty() {
         if std::fs::exists(&defl_config_path).unwrap_or(false) {
@@ -64,6 +71,7 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {
     if cli_config.print_config {
         log::info!("Skipping tokenizer config load");
     } else {
+        // TODO: the target and draft engines must use the same tokenizer; add a sanity check here?
        let tokenizer_folder = cli_config.tokenizer.as_ref().unwrap_or(&cli_config.engine);
        let tokenizer_config = format!("{}/tokenizer_config.json", tokenizer_folder);
        log::info!("Loading tokenizer config from {:?}", tokenizer_config);
@@ -196,6 +204,12 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {
     set_field_opt!(event_buffer_max_size);
     p.kv_cache_host_memory_bytes = runtime_config.kv_cache_host_memory_megabytes * 1024 * 1024;
 
+    if draft_exec_config.is_some() {
+        // kv-cache reuse should be enabled when running with a draft model
+        p.enable_kv_cache_reuse = true;
+        p.n_draft_tokens = Some(5);
+    }
+
     log::info!("Initializing executor with config: {:?}", exec_config);
 
     if cli_config.test_py {
@@ -205,7 +219,7 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {
         return Ok(());
     }
 
-    let (executor, tok_env, chat_builder) = AsyncExecutor::new(&cli_config, &config, exec_config)?;
+    let (executor, tok_env, chat_builder) =
+        AsyncExecutor::new(&cli_config, &config, exec_config, draft_exec_config)?;
 
     // we only get here on rank 0
 
@@ -254,6 +268,7 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {
     if state.py_state.enabled {
         log::info!("Skipping warmup due to python");
     } else {
+        // TODO: also warm up the draft executor here
         // warmup request
         log::info!("Warming up executor");
         let mut warmup_tokens =
@@ -269,6 +284,7 @@ pub async fn run_server(mut cli_config: CliConfig) -> anyhow::Result<()> {
                 client_req_id: ClientReqId::new(1),
                 lora_params: None,
                 is_run: false,
+                draft_params: None, // TODO fill out if warming up speculative decoding
             },
             None,
             vec![],
diff --git a/trtllm-c/main.cpp b/trtllm-c/main.cpp
index 1cd0d24..7def46d 100644
--- a/trtllm-c/main.cpp
+++ b/trtllm-c/main.cpp
@@ -34,6 +34,7 @@ struct ResponseData
     std::string error;
     tle::VecTokens tokens;
     tle::VecLogProbs logprobs;
+    tle::Tensor logitsTensor;
 };
 
 struct TlcExecutor
@@ -150,6 +151,16 @@ tle::Shape _tlc_to_tle_shape(TlcShape tlc_shape)
     return tle::Shape(tlc_shape.dims, tlc_shape.num_dims);
 }
 
+TlcShape _tle_to_tlc_shape(const tle::Shape& tle_shape)
+{
+    TlcShape tlc_shape;
+    tlc_shape.num_dims = std::min(tle_shape.size(), static_cast<size_t>(TLC_MAX_SHAPE)); // ensure within max shape
+
+    std::memcpy(tlc_shape.dims, tle_shape.data(), tlc_shape.num_dims * sizeof(int64_t));
+
+    return tlc_shape;
+}
+
 static tle::DataType to_tle_datatype(TlcDataType t)
 {
     switch (t)
@@ -167,6 +178,23 @@ static tle::DataType to_tle_datatype(TlcDataType t)
     }
 }
 
+static TlcDataType to_tlc_datatype(tle::DataType t)
+{
+    switch (t)
+    {
+    case tle::DataType::kBOOL: return TLC_DT_BOOL;
+    case tle::DataType::kUINT8: return TLC_DT_U8;
+    case tle::DataType::kINT8: return TLC_DT_I8;
+    case tle::DataType::kINT32: return TLC_DT_I32;
+    case tle::DataType::kINT64: return TLC_DT_I64;
+    case tle::DataType::kBF16: return TLC_DT_BF16;
+    case tle::DataType::kFP8: return TLC_DT_F8;
+    case tle::DataType::kFP16: return TLC_DT_F16;
+    case tle::DataType::kFP32: return TLC_DT_F32;
+    default: throw std::runtime_error("Unsupported data type");
+    }
+}
+
 static nvinfer1::DataType to_nvinfer_datatype(TlcDataType t)
 {
     static_assert((int) TLC_DT_F32 == (int) nvinfer1::DataType::kFLOAT);
@@ -238,6 +266,13 @@ TlcStatus tlc_enqueue_request(TlcExecutor* ctx, TlcRequest const* request, TlcRe
 
     tle::SamplingConfig samplingConfig;
 
+    // Reference from the TRT-LLM examples for attaching external draft tokens to a request;
+    // TlcDraftParams needs to be mapped onto this (target types: tle::VecTokens / tle::VecLogProbs).
+    // TODO: confirm this is how speculative decoding is invoked.
+    // tle::ExternalDraftTokensConfig draftTokensConfig(
+    //     std::move(draftTokens), logitsTensor, std::nullopt /* acceptance threshold */, runtimeOpts.fastLogits);
+    // request.setExternalDraftTokensConfig(draftTokensConfig);
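+    // (The actual TlcDraftParams -> ExternalDraftTokensConfig mapping happens further
+    // below, once the tle::Request `req` has been built.)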
+
     auto const& p = request->params;
     auto const& pp = request->prompt_params;
 
@@ -327,6 +362,19 @@ TlcStatus tlc_enqueue_request(TlcExecutor* ctx, TlcRequest const* request, TlcRe
         req.setLoraConfig(loraConfig);
     }
 
+    // If we have draft params, build the external draft tokens config
+    if (request->draft_params.draft_tokens && request->draft_params.logits_tensor.data_ptr)
+    {
+        auto const& dp = request->draft_params;
+        tle::Tensor logitsTensor = _tlc_to_tle_tensor(dp.logits_tensor);
+        assert(dp.num_tokens > 0);
+        tle::VecTokens draftTokens(dp.draft_tokens, dp.draft_tokens + dp.num_tokens);
+        tle::ExternalDraftTokensConfig draftTokensConfig(
+            std::move(draftTokens), logitsTensor, std::nullopt, std::nullopt);
+        req.setExternalDraftTokensConfig(draftTokensConfig);
+    }
+
     std::vector<tle::Request> requests;
     requests.emplace_back(std::move(req));
     auto ids = ctx->executor.enqueueRequests(std::move(requests));
@@ -399,6 +447,14 @@ TlcStatus tlc_await_responses(
             }
             assert(result.outputTokenIds.size() == 1);
             resp_data.tokens = result.outputTokenIds.at(0);
+
+            // Grab generation logits; TODO: this requires the executor to return generation
+            // logits, and still needs a check for whether streaming vs non-streaming matters here
+            auto generationLogits = result.generationLogits.value();
+            auto logitsShape = generationLogits.getShape();
+            assert(logitsShape[0] == 1);
+            resp_data.logitsTensor
+                = tle::Tensor::cpu(generationLogits.getDataType(), {logitsShape[1], logitsShape[2]});
+            std::memcpy(resp_data.logitsTensor.getData(), generationLogits.getData(),
+                generationLogits.getSizeInBytes());
+
             if (result.logProbs.has_value())
             {
                 assert(result.logProbs->size() == 1);
@@ -430,6 +486,9 @@ TlcStatus tlc_await_responses(
             c_resp.num_logprobs = data.logprobs.size();
             if (c_resp.num_logprobs > 0) c_resp.logprobs = data.logprobs.data();
+            c_resp.generation_logits.data_type = to_tlc_datatype(data.logitsTensor.getDataType());
+            c_resp.generation_logits.data_ptr = data.logitsTensor.getData();
+            c_resp.generation_logits.shape = _tle_to_tlc_shape(data.logitsTensor.getShape());
         }
         ctx->responses.emplace_back(c_resp);
diff --git a/trtllm-c/tlc.h b/trtllm-c/tlc.h
index 5930864..16e6bc3 100644
--- a/trtllm-c/tlc.h
+++ b/trtllm-c/tlc.h
@@ -136,6 +136,13 @@ extern "C"
         TlcTensor config;
     } TlcLoraParams;
 
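+    /// Draft-model tokens (and optionally their logits) attached to a request,
+    /// forwarded to TRT-LLM's ExternalDraftTokensConfig for speculative decoding.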
+    typedef struct
+    {
+        int32_t* draft_tokens;
+        uint32_t num_tokens;
+        TlcTensor logits_tensor;
+    } TlcDraftParams;
+
     typedef struct
     {
         bool use_logits_post_processor;
@@ -182,6 +189,7 @@ extern "C"
         TlcRequestParams params;
         TlcLoraParams lora_params;
         TlcPromptParams prompt_params;
+        TlcDraftParams draft_params;
     } TlcRequest;
 
     /// @brief The reason why the model stopped generating tokens for a request.
@@ -215,6 +223,7 @@ extern "C"
         int32_t const* tokens;
         uint32_t num_logprobs;
         float const* logprobs;
+        TlcTensor generation_logits;
     } TlcResponse;
 
     typedef struct TlcExecutor TlcExecutor;
diff --git a/trtllm_rs/src/tlc.rs b/trtllm_rs/src/tlc.rs
index 6330006..811c899 100644
--- a/trtllm_rs/src/tlc.rs
+++ b/trtllm_rs/src/tlc.rs
@@ -30,6 +30,7 @@ pub struct ResponseChunk {
     pub tokens: Vec<u32>,
     pub logprobs: Option<Vec<Vec<f32>>>,
     pub is_req_final: bool,
+    pub generation_logits: Option<Tensor>,
 }
 
 #[derive(Debug, Clone, PartialEq, Eq)]
@@ -138,6 +139,14 @@ pub struct LoraParams {
     pub config: Option<Tensor>,
 }
 
+#[derive(Debug, Clone, Default)]
+pub struct DraftParams {
+    pub draft_tokens: Vec<u32>, // TODO: must match the executor's token type (tle::VecTokens)
+    pub num_tokens: u32,
+    pub logits_tensor: Option<Tensor>, // draft-model logits for the proposed tokens
+}
+
 #[derive(Debug, Clone)]
 pub struct RequestInit {
     pub tokens: Vec<u32>,
@@ -145,6 +154,7 @@ pub struct RequestInit {
     pub client_req_id: ClientReqId,
     pub is_run: bool,
     pub params: RequestParams,
     pub lora_params: Option<LoraParams>,
+    pub draft_params: Option<DraftParams>, // draft tokens, logits, acceptance threshold
 }
 
 unsafe impl Send for ffi::TlcPromptParams {}
@@ -236,6 +246,10 @@ impl ffi::TlcShape {
         }
         r
     }
+
+    pub fn to_vec(&self) -> Vec<i64> {
+        self.dims[..self.num_dims].to_vec()
+    }
 }
 
 impl Default for ffi::TlcTensor {
@@ -273,6 +287,23 @@ impl Tensor {
             data_type: tensor.dtype,
         }
     }
+
+    // TODO: handle data types other than f32
+    pub fn from_tlc_tensor(tlc_tensor: &ffi::TlcTensor) -> Self {
+        let shape = tlc_tensor.shape.to_vec();
+
+        let data_ptr = tlc_tensor.data_ptr as *const f32;
+        let num_elements = shape.iter().product::<i64>() as usize;
+        let data: Vec<f32> =
+            unsafe { std::slice::from_raw_parts(data_ptr, num_elements).to_vec() };
+
+        Tensor {
+            size: shape,
+            data,
+            dtype: tlc_tensor.data_type,
+        }
+    }
 }
 
 impl Executor {
@@ -339,6 +370,19 @@ impl Executor {
             arg.lora_params = lp;
         }
 
+        // Load draft params into the request if we have any.
+        if let Some(draft_params) = &init.draft_params {
+            let mut dp = ffi::TlcDraftParams {
+                draft_tokens: draft_params.draft_tokens.as_ptr() as *mut i32,
+                num_tokens: draft_params.draft_tokens.len() as u32,
+                logits_tensor: ffi::TlcTensor::default(),
+            };
+            if let Some(logits) = &draft_params.logits_tensor {
+                dp.logits_tensor = logits.as_tlc_tensor();
+            }
+            arg.draft_params = dp;
+        }
+
         let mut req_id = 0;
         let err = unsafe { ffi::tlc_enqueue_request(self.inner, &arg, &mut req_id) };
         map_err(err, ReqId(req_id), "tlc_enqueue_request")
@@ -454,6 +498,7 @@ impl Responder {
                     },
                     logprobs,
                     tokens,
+                    generation_logits: if resp.generation_logits.data_ptr.is_null() {
+                        None
+                    } else {
+                        Some(Tensor::from_tlc_tensor(&resp.generation_logits))
+                    },
                 }
             })
             .collect())