Skip to content

Commit 274d162

Browse files
Fix SparseAttention cos/sin cache dimension checks (#20609)
### Description This PR fixes the dimension checks for the cos/sin caches used in the rotary embeddings in the `SparseAttention` operator. ### Motivation and Context This PR ports over the same changes from [this PR](#20547) for `GroupQueryAttention`.
1 parent 58d7b12 commit 274d162

File tree

1 file changed

+4
-4
lines changed

1 file changed

+4
-4
lines changed

onnxruntime/contrib_ops/cuda/sparse/sparse_attention_helper.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -202,13 +202,13 @@ Status CheckInputs(void* params,
202202
"head_size shall be a multiple of 16. Got head_size = ",
203203
head_size);
204204
}
205-
if (cos_dims[0] < max_sequence_length) {
205+
if (cos_dims[0] < total_sequence_length) {
206206
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
207-
"cos_cache dimension 0 should be of max_sequence_length.");
207+
"cos_cache dimension 0 should be not be less than total_sequence_length.");
208208
}
209-
if (sin_dims[0] < max_sequence_length) {
209+
if (sin_dims[0] < total_sequence_length) {
210210
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,
211-
"sin_cache dimension 0 should be of max_sequence_length.");
211+
"sin_cache dimension 0 should be not be less than total_sequence_length.");
212212
}
213213
if (cos_dims[1] > (head_size / 16) * 8 || cos_dims[1] % 8 != 0) {
214214
return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT,

0 commit comments

Comments
 (0)