Skip to content

Commit efc9e12

Browse files
jiafatomfs-eire
authored and committed
Fix Memory Issue sparse_attention Rotary (#26278)
### Description From an internal user, we see that sparse attention has a memory issue similar to #22290, so we follow that PR to make the change. ### Motivation and Context SparseAttention memory issue.
1 parent abb37be commit efc9e12

File tree

1 file changed

+7
-11
lines changed

1 file changed

+7
-11
lines changed

onnxruntime/contrib_ops/cpu/sparse/sparse_attention.cc

Lines changed: 7 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
130130
allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
131131
}
132132

133+
OrtValue RotaryQKV;
134+
OrtValue RotaryQ;
135+
OrtValue RotaryK;
136+
T* q_rotary = Q.GetMutable<Tensor>()->MutableData<T>();
137+
T* k_rotary = packed_qkv ? nullptr : K.GetMutable<Tensor>()->MutableData<T>();
133138
if (do_rotary_) {
134139
rotary_embedding_helper::RotaryParameters rotary_params = {};
135140
rotary_params.batch_size = batch_size;
@@ -167,30 +172,22 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
167172

168173
const T* q_input;
169174
const T* k_input;
170-
T* q_rotary;
171-
T* k_rotary;
172175
if (packed_qkv) {
173-
OrtValue RotaryQKV;
174176
TensorShape qkv_shape({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size});
175177
Tensor::InitOrtValue(element_type, qkv_shape, allocator, RotaryQKV);
176178
q_input = Q.Get<Tensor>().Data<T>();
177179
k_input = q_input + num_heads_ * sequence_length * head_size;
178180
q_rotary = RotaryQKV.GetMutable<Tensor>()->MutableData<T>();
179181
k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
180-
Q = RotaryQKV;
181182
} else {
182-
OrtValue RotaryQ;
183183
TensorShape q_shape({batch_size, num_heads_, sequence_length, head_size});
184184
Tensor::InitOrtValue(element_type, q_shape, allocator, RotaryQ);
185-
OrtValue RotaryK;
186185
TensorShape k_shape({batch_size, kv_num_heads_, sequence_length, head_size});
187186
Tensor::InitOrtValue(element_type, k_shape, allocator, RotaryK);
188187
q_input = Q.Get<Tensor>().Data<T>();
189188
k_input = K.Get<Tensor>().Data<T>();
190189
q_rotary = RotaryQ.GetMutable<Tensor>()->MutableData<T>();
191190
k_rotary = RotaryK.GetMutable<Tensor>()->MutableData<T>();
192-
Q = RotaryQ;
193-
K = RotaryK;
194191
}
195192

196193
ORT_RETURN_IF_ERROR(RunRotaryEmbedding<T>(tp, rotary_params, q_input,
@@ -221,9 +218,8 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
221218

222219
ORT_RETURN_IF_ERROR(context->GetTempSpaceAllocator(&allocator));
223220
// Compute the attention score and apply the score to V
224-
return ApplyAttention(Q.Get<Tensor>().Data<T>(), packed_qkv ? nullptr : K.Get<Tensor>().Data<T>(),
225-
packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(), past_key, past_value,
226-
output, present_key, present_value,
221+
return ApplyAttention(q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get<Tensor>().Data<T>(),
222+
past_key, past_value, output, present_key, present_value,
227223
total_key_lengths, block_row_indices, block_col_indices, parameters, allocator, context);
228224
}
229225
} // namespace contrib

0 commit comments

Comments
 (0)