133 changes: 96 additions & 37 deletions llgtrt/src/async_exec.rs
@@ -12,8 +12,7 @@ use std::{
use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
use toktrie::{SimpleVob, TokEnv};
use trtllm_rs::{
ClientReqId, Executor, ExecutorInit, MaskAllocator, ReqId, RequestInit, ResponseChunk,
TlcLogitsEntry,
    ClientReqId, Executor, ExecutorInit, MaskAllocator, ReqId, RequestInit, Responder,
    ResponseChunk, TlcLogitsEntry,
};

use crate::{
@@ -78,10 +77,12 @@ impl Display for ReqData {

pub struct AsyncExecutor {
executor: Executor,
draft_executor: Option<Executor>,
n_vocab: usize,
max_batch_size: usize,
req_to_client: HashMap<ReqId, ClientReqId>,
req_data: HashMap<ClientReqId, ReqData>,
    n_draft_tokens: u32, // TODO: find a better home for this, especially if it becomes adjustable between calls
}

static mut GLOBAL_ALLOCATOR: *const MaskAllocator = ptr::null();
@@ -364,16 +365,35 @@ impl AsyncExecutor {
self.executor.cancel_request(req_id)
}

pub fn has_draft_model(&self) -> bool {
self.draft_executor.is_some()
}

pub fn n_draft_tokens(&self) -> u32 {
self.n_draft_tokens
}

pub fn new(
cli_config: &CliConfig,
config: &LlgTrtConfig,
mut executor_init: ExecutorInit,
draft_executor_init: Option<ExecutorInit>,
) -> Result<(Self, TokEnv, ChatBuilder)> {
executor_init.logits_callback = Some(logits_processor);
let max_batch_size = executor_init.trt_params.max_batch_size as usize;
let mut max_batch_size = executor_init.trt_params.max_batch_size as usize;
log::info!("new executor: max_batch_size={max_batch_size}");
let (executor, mut responder) = Executor::new(executor_init)?;

let (draft_executor, draft_responder) = if let Some(mut draft_init) = draft_executor_init {
    draft_init.logits_callback = Some(logits_processor);
    max_batch_size = draft_init.trt_params.max_batch_size as usize;
    log::info!("new draft executor: max_batch_size={max_batch_size}");
    let (executor, responder) = Executor::new(draft_init)?;
    (Some(executor), Some(responder))
} else {
    (None, None)
};

// on non-0 ranks, this will just wait until the rank 0 exits and then exit the process
executor.check_mpi();

@@ -384,62 +404,102 @@

let res = Self {
executor,
draft_executor,
req_data: HashMap::new(),
req_to_client: HashMap::new(),
n_vocab,
max_batch_size,
};
// Worker loop shared by the main and (optional) draft executor: drain responses
// from the responder and forward them to the per-request channels.
let receive_from_responder = |mut responder: Responder| {
    move || loop {
        let resps = responder
            .await_responses(std::time::Duration::from_millis(1))
            .unwrap();

        if resps.len() == 0 {
            continue;
        }

        let mut exec = AsyncExecutor::lock();
        for resp in resps {
            let req_id = resp.req_id;
            if let Some(client_req_id) = exec.req_to_client.get(&req_id) {
                let client_req_id = *client_req_id;
                let rd = exec.req_data.get_mut(&client_req_id).unwrap();
                let is_req_final = resp.is_req_final;
                let idx = resp.sequence_idx as usize;

                let mut r = StepResults {
                    response: resp,
                    logs: std::mem::take(&mut rd.logs),
                    final_llg: None,
                };
                if rd.llgs.len() > 0 && r.response.finish_reason.is_some() {
                    r.final_llg = std::mem::take(&mut rd.llgs[idx]);
                }
                if rd.tx.send(r).is_err() {
                    log::warn!("connection dropped; req={}", req_id);
                    let _ = exec.cancel_request(req_id);
                } else if is_req_final {
                    // no more data coming from here
                    exec.drop_request_data(req_id);
                }
            } else {
                log::warn!("Response for unknown request: {:?}", req_id);
                let _ = exec.executor.cancel_request(req_id);
            }
        }
    }
};

rayon::spawn(receive_from_responder(responder));

if let Some(draft_responder) = draft_responder {
    rayon::spawn(receive_from_responder(draft_responder));
}

Ok((res, tok_env, chat_builder))
}

pub fn can_enqueue_request(&self) -> bool {
self.executor.can_enqueue_request()
}

pub fn add_draft_request(
    &mut self,
    init: &RequestInit,
    prompt_params: Option<Arc<PyPromptParams>>,
    llgs: Vec<Box<Constraint>>,
) -> Result<(ReqId, UnboundedReceiver<StepResults>)> {
    debug_assert!(self.draft_executor.is_some());
    // TODO: adjust init.params for the draft request (e.g. number of draft tokens) before enqueuing
    self.add_request_to_executor(true, init, prompt_params, llgs)
}

pub fn add_request(
    &mut self,
    init: &RequestInit,
    prompt_params: Option<Arc<PyPromptParams>>,
    llgs: Vec<Box<Constraint>>,
) -> Result<(ReqId, UnboundedReceiver<StepResults>)> {
    self.add_request_to_executor(false, init, prompt_params, llgs)
}

fn add_request_to_executor(
    &mut self,
    use_draft: bool,
    init: &RequestInit,
    prompt_params: Option<Arc<PyPromptParams>>,
    llgs: Vec<Box<Constraint>>,
) -> Result<(ReqId, UnboundedReceiver<StepResults>)> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

@@ -451,7 +511,6 @@

let pp = prompt_params.as_ref().map(|p| &p.tlc_prompt_params);

// we're locked here, so it's safe to insert only after enqueuing
let req_id = if use_draft {
    self.draft_executor
        .as_mut()
        .expect("draft executor not configured")
        .enqueue_request(init, pp)?
} else {
    self.executor.enqueue_request(init, pp)?
};
self.req_data.insert(
client_req_id,
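Reviewer note: a rough call-site sketch (not part of this diff) of how the new draft-model surface (`has_draft_model`, `n_draft_tokens`, `add_draft_request`) could be driven from the request path. The function name and the `init`/`prompt_params`/`llgs` values are placeholders the caller would already have; only the `AsyncExecutor` calls mirror this PR.

// Sketch only, assuming the signatures introduced above (same module as AsyncExecutor).
fn enqueue_with_optional_draft(
    init: &RequestInit,
    prompt_params: Option<Arc<PyPromptParams>>,
    llgs: Vec<Box<Constraint>>,
) -> Result<(ReqId, UnboundedReceiver<StepResults>)> {
    let mut exec = AsyncExecutor::lock();
    if exec.has_draft_model() {
        // n_draft_tokens() would bound how many speculative tokens each step requests
        let _max_draft = exec.n_draft_tokens();
        exec.add_draft_request(init, prompt_params, llgs)
    } else {
        exec.add_request(init, prompt_params, llgs)
    }
}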
9 changes: 9 additions & 0 deletions llgtrt/src/config.rs
@@ -82,6 +82,7 @@ impl Default for TrtLlmRuntimeConfig {
cross_kv_cache_fraction: None,
secondary_offload_min_priority: None,
event_buffer_max_size: None,
            n_draft_tokens: None,
}
}
}
@@ -158,6 +159,14 @@ pub struct CliConfig {
#[arg(long, short = 'T')]
pub tokenizer: Option<String>,

/// Path to a compiled TensorRT-LLM draft engine
#[arg(long)]
pub draft_engine: Option<String>,

    /// Path to folder with HF tokenizer.json and tokenizer_config.json files for the draft engine; defaults to --draft-engine
#[arg(long)]
pub draft_tokenizer: Option<String>,

/// Debug output
#[arg(long, short = 'd')]
pub debug: bool,
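Reviewer note: a hedged sketch of how the new `--draft-engine` / `--draft-tokenizer` flags might be threaded into `AsyncExecutor::new` at startup. `load_engine_init` is a hypothetical helper standing in for however llgtrt builds the main `ExecutorInit`; the point is only that the draft init is optional and the draft tokenizer falls back to the engine directory.

// Sketch only -- load_engine_init is a placeholder, not an existing function.
let draft_executor_init = match &cli_config.draft_engine {
    Some(engine_dir) => {
        let draft_tokenizer_dir = cli_config
            .draft_tokenizer
            .as_deref()
            .unwrap_or(engine_dir.as_str());
        Some(load_engine_init(engine_dir, draft_tokenizer_dir)?)
    }
    None => None,
};
let (async_executor, tok_env, chat_builder) =
    AsyncExecutor::new(&cli_config, &config, executor_init, draft_executor_init)?;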
2 changes: 1 addition & 1 deletion llgtrt/src/main.rs
@@ -25,7 +25,7 @@ async fn main() -> anyhow::Result<()> {
// for every request when running with mpirun
// env::set_var("TLLM_LOG_LEVEL", "INFO");
}
}
}

llgtrt::logging::init_log(llgtrt::logging::LogMode::Normal)?;
