
Commit 0f2fbd8

bugfix: fix runtime error for mlu.
1 parent: a533db2

4 files changed: 17 additions, 2 deletions

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 14 additions & 0 deletions
@@ -557,6 +557,14 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
   input_params.decode_seq_range =
       util::find_ones_indices(input_params.q_seq_lens_vec);
 
+  // for flashinfer
+  input_params.paged_kv_indptr =
+      torch::tensor(state_.paged_kv_indptr, torch::kInt);
+  input_params.paged_kv_indices =
+      torch::tensor(state_.paged_kv_indices, torch::kInt);
+  input_params.paged_kv_last_page_len =
+      torch::tensor(state_.paged_kv_last_page_len, torch::kInt);
+
   // Setup multimodal data
   input_params.mm_data = MMData::batch(mm_data_vec_);
 
@@ -631,6 +639,12 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
   raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
   raw_forward_input.prefill_seq_len = state_.prefill_seq_len;
 
+  // for flashinfer
+  raw_forward_input.paged_kv_indptr = std::move(state_.paged_kv_indptr);
+  raw_forward_input.paged_kv_indices = std::move(state_.paged_kv_indices);
+  raw_forward_input.paged_kv_last_page_len =
+      std::move(state_.paged_kv_last_page_len);
+
   raw_forward_input.embedding_ids = std::move(state_.embedding_ids);
   raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids);
 
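
The three paged_kv_* fields added above follow the CSR-style layout that flashinfer-style paged-KV kernels consume: indptr holds per-sequence offsets into a flat list of KV block ids, indices holds those block ids, and last_page_len records how many tokens occupy each sequence's final block. The following is a minimal sketch of how such metadata is typically assembled; the struct, helper names, and block_size parameter are illustrative assumptions, not code from this repository.

#include <cstdint>
#include <vector>

#include <torch/torch.h>

// Hypothetical container mirroring the three fields added to ForwardInput.
struct PagedKvMetadata {
  std::vector<int32_t> indptr;         // CSR offsets, size num_seqs + 1
  std::vector<int32_t> indices;        // concatenated KV block ids
  std::vector<int32_t> last_page_len;  // tokens used in each sequence's last block
};

// Sketch: build CSR metadata from per-sequence block lists and KV lengths.
PagedKvMetadata build_paged_kv_metadata(
    const std::vector<std::vector<int32_t>>& block_tables,
    const std::vector<int32_t>& kv_seq_lens,
    int32_t block_size) {
  PagedKvMetadata meta;
  meta.indptr.push_back(0);
  for (size_t i = 0; i < block_tables.size(); ++i) {
    meta.indices.insert(meta.indices.end(),
                        block_tables[i].begin(),
                        block_tables[i].end());
    meta.indptr.push_back(static_cast<int32_t>(meta.indices.size()));
    // For kv_seq_lens[i] > 0: a completely full final block contributes
    // block_size tokens, not zero.
    const int32_t rem = kv_seq_lens[i] % block_size;
    meta.last_page_len.push_back(rem == 0 ? block_size : rem);
  }
  return meta;
}

// Conversion to int32 tensors, as state_to_forward_input() does above.
torch::Tensor to_int_tensor(const std::vector<int32_t>& v) {
  return torch::tensor(v, torch::kInt);
}

With this layout, indptr[i + 1] - indptr[i] is the number of KV blocks owned by sequence i, which is how a kernel walks indices per sequence.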

xllm/core/kernels/ops_api.cpp

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ void batch_decode(AttentionParams& params) {
   mlu::batch_decode(params.query,
                     params.k_cache,
                     params.output,
-                    params.block_table,
+                    params.block_table.value(),
                     params.kv_seq_lens,
                     params.v_cache,
                     params.output_lse,

xllm/core/kernels/param.h

Lines changed: 1 addition & 1 deletion
@@ -72,7 +72,7 @@ struct AttentionParams {
   std::optional<torch::Tensor> alibi_slope;
   std::optional<torch::Tensor> q_quant_scale;
   std::optional<torch::Tensor> out_quant_scale;
-  torch::Tensor block_table;
+  std::optional<torch::Tensor> block_table;
   std::string compute_dtype;
   int max_seq_len;
   int window_size_left;
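
Taken together, the two hunks above change block_table from a plain torch::Tensor into std::optional<torch::Tensor> and unwrap it with .value() at the MLU call site. A minimal sketch of the std::optional semantics this relies on (standalone example, not code from the repository):

#include <optional>

#include <torch/torch.h>

int main() {
  std::optional<torch::Tensor> block_table;      // default-constructed: empty
  const bool was_set = block_table.has_value();  // false here

  block_table = torch::zeros({4, 16}, torch::kInt);  // engage the optional
  torch::Tensor unwrapped = block_table.value();  // as ops_api.cpp now does;
                                                  // throws std::bad_optional_access if unset
  return was_set ? 1 : 0;
}

Making the field optional lets callers that do not need a block table leave it unset, while the MLU path still receives a plain torch::Tensor at the kernel boundary.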

xllm/core/layers/mlu/attention.cpp

Lines changed: 1 addition & 0 deletions
@@ -139,6 +139,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
 
   // for mlu
   attention_params.block_table = attn_metadata.block_table;
+  attention_params.kv_seq_lens = attn_metadata.kv_seq_lens;
 
   // for flashinfer
   attention_params.paged_kv_indptr = attn_metadata.paged_kv_indptr;
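
The added line populates kv_seq_lens for the MLU path alongside block_table; both are consumed by mlu::batch_decode in ops_api.cpp above. The helper below is hypothetical (not part of the commit) and only illustrates the kind of precondition whose violation would surface as a runtime error inside the kernel call:

#include <optional>

#include <torch/torch.h>

// Hypothetical guard one could run before dispatching to the MLU decode kernel.
void check_mlu_decode_inputs(const std::optional<torch::Tensor>& block_table,
                             const torch::Tensor& kv_seq_lens) {
  TORCH_CHECK(block_table.has_value(),
              "block_table must be set for mlu::batch_decode");
  TORCH_CHECK(kv_seq_lens.defined(),
              "kv_seq_lens must be set for mlu::batch_decode");
}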
