Skip to content

Commit 1e1fba9

Browse files
committed
refactor: update output_lse parameter to use std::optional in batch_prefill and batch_decode functions.
1 parent a5fa4d8 commit 1e1fba9

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

xllm/core/kernels/mlu/attention.cpp

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -32,7 +32,7 @@ void batch_prefill(const torch::Tensor& query,
3232
const torch::Tensor& key,
3333
const torch::Tensor& value,
3434
torch::Tensor& output,
35-
torch::Tensor& output_lse,
35+
std::optional<torch::Tensor>& output_lse,
3636
const std::optional<torch::Tensor>& query_start_loc,
3737
const std::optional<torch::Tensor>& seq_start_loc,
3838
const std::optional<torch::Tensor>& alibi_slope,
@@ -80,7 +80,7 @@ void batch_decode(const torch::Tensor& query,
8080
const torch::Tensor& block_table,
8181
const torch::Tensor& seq_lens,
8282
const torch::Tensor& v_cache,
83-
torch::Tensor& output_lse,
83+
std::optional<torch::Tensor>& output_lse,
8484
const std::optional<torch::Tensor>& q_quant_scale,
8585
const std::optional<torch::Tensor>& k_cache_quant_scale,
8686
const std::optional<torch::Tensor>& v_cache_quant_scale,

xllm/core/kernels/mlu/mlu_ops_api.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ void batch_prefill(const torch::Tensor& query,
6262
const torch::Tensor& key,
6363
const torch::Tensor& value,
6464
torch::Tensor& output,
65-
torch::Tensor& output_lse,
65+
std::optional<torch::Tensor>& output_lse,
6666
const std::optional<torch::Tensor>& query_start_loc,
6767
const std::optional<torch::Tensor>& seq_start_loc,
6868
const std::optional<torch::Tensor>& alibi_slope,
@@ -87,7 +87,7 @@ void batch_decode(const torch::Tensor& query,
8787
const torch::Tensor& block_table,
8888
const torch::Tensor& seq_lens,
8989
const torch::Tensor& v_cache,
90-
torch::Tensor& output_lse,
90+
std::optional<torch::Tensor>& output_lse,
9191
const std::optional<torch::Tensor>& q_quant_scale,
9292
const std::optional<torch::Tensor>& k_cache_quant_scale,
9393
const std::optional<torch::Tensor>& v_cache_quant_scale,

xllm/core/kernels/ops_api.cpp

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,12 +71,11 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) {
7171

7272
void batch_prefill(AttentionParams& params) {
7373
#if defined(USE_MLU)
74-
torch::Tensor lse = params.output_lse.value_or(torch::Tensor());
7574
mlu::batch_prefill(params.query,
7675
params.key,
7776
params.value,
7877
params.output,
79-
lse,
78+
params.output_lse,
8079
params.query_start_loc,
8180
params.seq_start_loc,
8281
params.alibi_slope,
@@ -94,7 +93,6 @@ void batch_prefill(AttentionParams& params) {
9493
params.window_size_right,
9594
params.compute_dtype,
9695
params.return_lse);
97-
params.output_lse = lse;
9896
#elif defined(USE_CUDA)
9997
throw std::runtime_error("batch_prefill for cuda not implemented");
10098
#else
@@ -104,14 +102,13 @@ void batch_prefill(AttentionParams& params) {
104102

105103
void batch_decode(AttentionParams& params) {
106104
#if defined(USE_MLU)
107-
torch::Tensor lse = params.output_lse.value_or(torch::Tensor());
108105
mlu::batch_decode(params.query,
109106
params.k_cache,
110107
params.output,
111108
params.block_table,
112109
params.seq_lens,
113110
params.v_cache,
114-
lse,
111+
params.output_lse,
115112
params.q_quant_scale,
116113
params.k_cache_quant_scale,
117114
params.v_cache_quant_scale,
@@ -125,7 +122,6 @@ void batch_decode(AttentionParams& params) {
125122
params.scale,
126123
params.return_lse,
127124
params.kv_cache_quant_bit_size);
128-
params.output_lse = lse;
129125
#elif defined(USE_CUDA)
130126
throw std::runtime_error("batch_decode for cuda not implemented");
131127
#else

0 commit comments

Comments
 (0)