Skip to content

Commit e8024bc

Browse files
committed
bugfix: fix runtime error for MLU by making `block_table` optional and unwrapping it at the call site.
1 parent a533db2 commit e8024bc

File tree

3 files changed

+3
-2
lines changed

3 files changed

+3
-2
lines changed

xllm/core/kernels/ops_api.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ void batch_decode(AttentionParams& params) {
139139
mlu::batch_decode(params.query,
140140
params.k_cache,
141141
params.output,
142-
params.block_table,
142+
params.block_table.value(),
143143
params.kv_seq_lens,
144144
params.v_cache,
145145
params.output_lse,

xllm/core/kernels/param.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ struct AttentionParams {
7272
std::optional<torch::Tensor> alibi_slope;
7373
std::optional<torch::Tensor> q_quant_scale;
7474
std::optional<torch::Tensor> out_quant_scale;
75-
torch::Tensor block_table;
75+
std::optional<torch::Tensor> block_table;
7676
std::string compute_dtype;
7777
int max_seq_len;
7878
int window_size_left;

xllm/core/layers/mlu/attention.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -139,6 +139,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> AttentionImpl::forward(
139139

140140
// for mlu
141141
attention_params.block_table = attn_metadata.block_table;
142+
attention_params.kv_seq_lens = attn_metadata.kv_seq_lens;
142143

143144
// for flashinfer
144145
attention_params.paged_kv_indptr = attn_metadata.paged_kv_indptr;

0 commit comments

Comments
 (0)