@@ -130,6 +130,11 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
130130 allocator, batch_size, kv_num_heads_, sequence_length, head_size, value, V));
131131 }
132132
133+ OrtValue RotaryQKV;
134+ OrtValue RotaryQ;
135+ OrtValue RotaryK;
136+ T* q_rotary = Q.GetMutable <Tensor>()->MutableData <T>();
137+ T* k_rotary = packed_qkv ? nullptr : K.GetMutable <Tensor>()->MutableData <T>();
133138 if (do_rotary_) {
134139 rotary_embedding_helper::RotaryParameters rotary_params = {};
135140 rotary_params.batch_size = batch_size;
@@ -167,30 +172,22 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
167172
168173 const T* q_input;
169174 const T* k_input;
170- T* q_rotary;
171- T* k_rotary;
172175 if (packed_qkv) {
173- OrtValue RotaryQKV;
174176 TensorShape qkv_shape ({batch_size, num_heads_ + 2 * kv_num_heads_, sequence_length, head_size});
175177 Tensor::InitOrtValue (element_type, qkv_shape, allocator, RotaryQKV);
176178 q_input = Q.Get <Tensor>().Data <T>();
177179 k_input = q_input + num_heads_ * sequence_length * head_size;
178180 q_rotary = RotaryQKV.GetMutable <Tensor>()->MutableData <T>();
179181 k_rotary = q_rotary + num_heads_ * sequence_length * head_size;
180- Q = RotaryQKV;
181182 } else {
182- OrtValue RotaryQ;
183183 TensorShape q_shape ({batch_size, num_heads_, sequence_length, head_size});
184184 Tensor::InitOrtValue (element_type, q_shape, allocator, RotaryQ);
185- OrtValue RotaryK;
186185 TensorShape k_shape ({batch_size, kv_num_heads_, sequence_length, head_size});
187186 Tensor::InitOrtValue (element_type, k_shape, allocator, RotaryK);
188187 q_input = Q.Get <Tensor>().Data <T>();
189188 k_input = K.Get <Tensor>().Data <T>();
190189 q_rotary = RotaryQ.GetMutable <Tensor>()->MutableData <T>();
191190 k_rotary = RotaryK.GetMutable <Tensor>()->MutableData <T>();
192- Q = RotaryQ;
193- K = RotaryK;
194191 }
195192
196193 ORT_RETURN_IF_ERROR (RunRotaryEmbedding<T>(tp, rotary_params, q_input,
@@ -221,9 +218,8 @@ Status SparseAttention<T>::Compute(OpKernelContext* context) const {
221218
222219 ORT_RETURN_IF_ERROR (context->GetTempSpaceAllocator (&allocator));
223220 // Compute the attention score and apply the score to V
224- return ApplyAttention (Q.Get <Tensor>().Data <T>(), packed_qkv ? nullptr : K.Get <Tensor>().Data <T>(),
225- packed_qkv ? nullptr : V.Get <Tensor>().Data <T>(), past_key, past_value,
226- output, present_key, present_value,
221+ return ApplyAttention (q_rotary, packed_qkv ? nullptr : k_rotary, packed_qkv ? nullptr : V.Get <Tensor>().Data <T>(),
222+ past_key, past_value, output, present_key, present_value,
227223 total_key_lengths, block_row_indices, block_col_indices, parameters, allocator, context);
228224}
229225} // namespace contrib
0 commit comments