Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions ktransformers/models/custom_modeling_deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,22 @@ def __init__(

def init_wrapper(self, use_cuda_graph, device, max_batch_size, max_pages):
self.use_cuda_graph = use_cuda_graph
self.workspace_buffer = torch.empty(128 * 1024 * 1024, dtype=torch.int8).to(0)
# Increase buffer sizes to be safe
self.workspace_buffer = torch.empty(256 * 1024 * 1024, dtype=torch.int8).to(0)
self.qo_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
self.paged_kv_indptr_buf = torch.empty((max_batch_size+2,), dtype=torch.int32, device=device)
self.paged_kv_indices_buf = torch.empty((max_pages,), dtype=torch.int32, device=device)
# Make sure this buffer is large enough
self.paged_kv_indices_buf = torch.empty((max_pages * 2,), dtype=torch.int32, device=device)
self.paged_kv_len_buf = torch.empty((max_batch_size+1,), dtype=torch.int32, device=device)
self.bsz_tensor_buf = torch.empty((1, ), dtype=torch.int32, device=device)


self.wrapper = flashinfer.mla.BatchMLAPagedAttentionWrapper(
self.workspace_buffer, use_cuda_graph=use_cuda_graph,
qo_indptr=self.qo_indptr_buf,kv_indptr=self.paged_kv_indptr_buf,
kv_indices=self.paged_kv_indices_buf,kv_len_arr=self.paged_kv_len_buf,
qo_indptr=self.qo_indptr_buf,
kv_indptr=self.paged_kv_indptr_buf,
kv_indices=self.paged_kv_indices_buf,
kv_len_arr=self.paged_kv_len_buf,
bsz_tensor=self.bsz_tensor_buf,
backend = "fa2",
)
Expand Down Expand Up @@ -145,4 +149,4 @@ def flash_infer_attn_plan(self, batch: ForwardBatchInput, bsz_tensors, num_token
minibatch = batch.minibatch
self.wrapper.plan(minibatch.q_indptr, minibatch.kv_indptr, minibatch.kv_indices,
minibatch.kv_len, num_heads, head_dim_ckv, head_dim_kpe, page_size, causal, sm_scale, q_data_type, kv_data_type, bsz_tensors)


22 changes: 16 additions & 6 deletions ktransformers/server/balance_serve/inference/sampling/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def __init__(self, bsz = 1, device = torch.device('cuda'), pretrained_config:Gen
self.is_all_greedy = False

class Sampler(nn.Module):
def __init__(self):
def __init__(self, device=torch.device('cuda')):
super().__init__()
self.device = device

def forward(
self,
Expand All @@ -63,15 +64,20 @@ def forward(
if sampling_config == None:
sampling_config = SamplingOptions()

logits = logits.contiguous()
# Ensure all tensors are on the same device
device = logits.device
logits = logits.contiguous().to(device)
sampling_config.temperatures = sampling_config.temperatures.to(device)

origin_logits = logits.clone()
if sampling_config.is_all_greedy:
# Use torch.argmax if all requests use greedy sampling
probs = logits
batch_next_token_ids = torch.argmax(logits, -1)
else:
# Post process logits
logits.div_(sampling_config.temperatures)
safe_temperatures = sampling_config.temperatures.masked_fill(sampling_config.temperatures == 0, 1.0)
logits.div_(safe_temperatures)
max_top_k_round, batch_size = 32, logits.shape[0]
if sampling_config.need_min_p_sampling:
probs = torch.softmax(logits, dim=-1)
Expand All @@ -82,8 +88,10 @@ def forward(
batch_next_token_ids = min_p_sampling_from_probs(
probs, sampling_config.min_ps
)
torch.cuda.synchronize()
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
if temperature_0_idx.numel() > 0:
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
else:
# TODO: use different kernel when don't need top_k or top_p
# @TODO get probs
Expand All @@ -94,7 +102,9 @@ def forward(
sampling_config.top_ps,
filter_apply_order="joint",
)
torch.cuda.synchronize()
temperature_0_idx = torch.where(sampling_config.temperatures == 0)[0]
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)
if temperature_0_idx.numel() > 0:
batch_next_token_ids[temperature_0_idx] = torch.argmax(origin_logits[temperature_0_idx], -1).to(torch.int32)

return batch_next_token_ids.to(torch.int32), probs
return batch_next_token_ids.to(torch.int32), probs