1 parent a5fa4d8 commit 1e1fba9
xllm/core/kernels/mlu/attention.cpp
@@ -32,7 +32,7 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& key,
                    const torch::Tensor& value,
                    torch::Tensor& output,
-                   torch::Tensor& output_lse,
+                   std::optional<torch::Tensor>& output_lse,
                    const std::optional<torch::Tensor>& query_start_loc,
                    const std::optional<torch::Tensor>& seq_start_loc,
                    const std::optional<torch::Tensor>& alibi_slope,
@@ -80,7 +80,7 @@ void batch_decode(const torch::Tensor& query,
                   const torch::Tensor& block_table,
                   const torch::Tensor& seq_lens,
                   const torch::Tensor& v_cache,
-                  torch::Tensor& output_lse,
+                  std::optional<torch::Tensor>& output_lse,
                   const std::optional<torch::Tensor>& q_quant_scale,
                   const std::optional<torch::Tensor>& k_cache_quant_scale,
                   const std::optional<torch::Tensor>& v_cache_quant_scale,
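
The change in both signatures is the `output_lse` parameter: a mandatory `torch::Tensor&` becomes `std::optional<torch::Tensor>&`, so the log-sum-exp output can be omitted when the caller does not request it. A minimal sketch of this out-parameter pattern, with a hypothetical `attend` function standing in for the real MLU kernels:

#include <optional>
#include <torch/torch.h>

// Hypothetical stand-in for batch_prefill/batch_decode: the LSE tensor is
// allocated and written only when the caller asked for it; an empty
// optional simply stays empty.
void attend(const torch::Tensor& query,
            torch::Tensor& output,
            std::optional<torch::Tensor>& output_lse,
            bool return_lse) {
  output = torch::softmax(query, /*dim=*/-1);  // placeholder for attention math
  if (return_lse) {
    output_lse = torch::logsumexp(query, /*dim=*/{-1});
  }
}

// A caller that does not want the LSE no longer has to supply a dummy tensor:
//   std::optional<torch::Tensor> lse;            // stays std::nullopt
//   attend(q, out, lse, /*return_lse=*/false);
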
xllm/core/kernels/mlu/mlu_ops_api.h
@@ -62,7 +62,7 @@ void batch_prefill(const torch::Tensor& query,
                    const torch::Tensor& key,
                    const torch::Tensor& value,
                    torch::Tensor& output,
-                   torch::Tensor& output_lse,
+                   std::optional<torch::Tensor>& output_lse,
                    const std::optional<torch::Tensor>& query_start_loc,
                    const std::optional<torch::Tensor>& seq_start_loc,
                    const std::optional<torch::Tensor>& alibi_slope,
@@ -87,7 +87,7 @@ void batch_decode(const torch::Tensor& query,
                  const torch::Tensor& block_table,
                  const torch::Tensor& seq_lens,
                  const torch::Tensor& v_cache,
-                 torch::Tensor& output_lse,
+                 std::optional<torch::Tensor>& output_lse,
                  const std::optional<torch::Tensor>& q_quant_scale,
                  const std::optional<torch::Tensor>& k_cache_quant_scale,
                  const std::optional<torch::Tensor>& v_cache_quant_scale,
xllm/core/kernels/ops_api.cpp
@@ -71,12 +71,11 @@ void reshape_paged_cache(ReshapePagedCacheParams& params) {
 
 void batch_prefill(AttentionParams& params) {
 #if defined(USE_MLU)
-  torch::Tensor lse = params.output_lse.value_or(torch::Tensor());
   mlu::batch_prefill(params.query,
                      params.key,
                      params.value,
                      params.output,
-                     lse,
+                     params.output_lse,
                      params.query_start_loc,
                      params.seq_start_loc,
                      params.alibi_slope,
@@ -94,7 +93,6 @@ void batch_prefill(AttentionParams& params) {
                      params.window_size_right,
                      params.compute_dtype,
                      params.return_lse);
-  params.output_lse = lse;
 #elif defined(USE_CUDA)
   throw std::runtime_error("batch_prefill for cuda not implemented");
 #else
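
With the kernel accepting the optional directly, the wrapper no longer needs the unwrap-and-write-back dance that the removed lines performed. A sketch of the two patterns side by side (`kernel_old`, `kernel_new`, and `wrapper` are illustrative names, not xllm functions):

#include <optional>
#include <torch/torch.h>

void kernel_old(torch::Tensor& lse) { lse = torch::zeros({1}); }
void kernel_new(std::optional<torch::Tensor>& lse) { /* fills lse only if requested */ }

void wrapper(std::optional<torch::Tensor>& output_lse) {
  // Old pattern (removed by this commit): default-construct a dummy tensor
  // when the optional is empty, call, then copy the result back -- the
  // optional ends up engaged even when no LSE was wanted.
  torch::Tensor lse = output_lse.value_or(torch::Tensor());
  kernel_old(lse);
  output_lse = lse;

  // New pattern: hand the optional straight to the kernel; it decides
  // whether to fill it, and an empty optional can stay empty.
  kernel_new(output_lse);
}
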
@@ -104,14 +102,13 @@ void batch_prefill(AttentionParams& params) {
 
 void batch_decode(AttentionParams& params) {
 #if defined(USE_MLU)
-  torch::Tensor lse = params.output_lse.value_or(torch::Tensor());
   mlu::batch_decode(params.query,
                     params.k_cache,
                     params.output,
                     params.block_table,
                     params.seq_lens,
                     params.v_cache,
-                    lse,
+                    params.output_lse,
                     params.q_quant_scale,
                     params.k_cache_quant_scale,
                     params.v_cache_quant_scale,
@@ -125,7 +122,6 @@ void batch_decode(AttentionParams& params) {
                     params.scale,
                     params.return_lse,
                     params.kv_cache_quant_bit_size);
-  params.output_lse = lse;
 #elif defined(USE_CUDA)
   throw std::runtime_error("batch_decode for cuda not implemented");
 #else
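
On the caller side, `AttentionParams::output_lse` now travels as a `std::optional<torch::Tensor>` end to end. A hedged usage sketch, assuming only the fields visible in this diff (everything else about `AttentionParams`, including default construction, is an assumption here):

#include <optional>
// assumes xllm's ops_api header, which declares AttentionParams and batch_prefill

void example() {
  AttentionParams params;
  // ... fill query/key/value/output and the other fields as usual ...
  params.return_lse = true;          // ask the kernel for the log-sum-exp
  params.output_lse = std::nullopt;  // kernel may allocate and fill this
  batch_prefill(params);
  if (params.output_lse.has_value()) {
    const torch::Tensor& lse = *params.output_lse;  // consume the result
    (void)lse;
  }
}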