Skip to content

Commit 9a15422

Browse files
authored
Fix rotary embedding to follow spec on attribute constraints (#26044)
### Description <!-- Describe your changes. --> `num_heads` is not necessarily required from users when input shape is 4D. ### Motivation and Context <!-- - Why is this change required? What problem does it solve? - If it fixes an open issue, please link to the issue here. --> To follow ONNX spec, https://github.com/onnx/onnx/blob/main/docs/Operators.md#RotaryEmbedding, the original constraints on attributes were wrong. NOTE: 3 rotary embedding tests are expected to be wrong until next release.
1 parent f3251de commit 9a15422

File tree

4 files changed

+22
-7
lines changed

4 files changed

+22
-7
lines changed

onnxruntime/core/providers/cpu/llm/rotary_embedding.cc

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,6 @@ RotaryEmbedding<T>::RotaryEmbedding(const OpKernelInfo& info) : OpKernel(info) {
3030
num_heads = static_cast<int>(info.GetAttrOrDefault<int64_t>("num_heads", 0));
3131
rotary_embedding_dim = static_cast<int>(info.GetAttrOrDefault<int64_t>("rotary_embedding_dim", 0));
3232
interleaved = (info.GetAttrOrDefault<int64_t>("interleaved", 0) == 1); // Turn 0/1 into bool
33-
34-
if (rotary_embedding_dim > 0) {
35-
ORT_ENFORCE(num_heads > 0, "num_heads must be provided if rotary_embedding_dim is specified");
36-
}
3733
}
3834

3935
// TODO: rotary embedding in place
@@ -111,6 +107,15 @@ Status RotaryEmbedding<T>::Compute(OpKernelContext* context) const {
111107
// Optional position_ids input, can be nullptr
112108
const Tensor* position_ids = context->Input<Tensor>(3);
113109

110+
// If rotary_embedding_dim is set (>0) and num_heads attribute not provided (==0),
111+
// we can only proceed if input is 4D (B, num_heads, S, head_size) so num_heads can be inferred.
112+
if (rotary_embedding_dim > 0 && num_heads <= 0) {
113+
const auto& dims = X->Shape().GetDims();
114+
ORT_ENFORCE(dims.size() == 4,
115+
"Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified "
116+
"and input is not rank-4 (batch, num_heads, sequence, head).");
117+
}
118+
114119
RotaryParameters parameters = {};
115120
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(X,
116121
position_ids,

onnxruntime/core/providers/cuda/llm/rotary_embedding.cc

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ Status RotaryEmbedding<T>::ComputeInternal(OpKernelContext* context) const {
4444
const Tensor* sin_cache = context->Input<Tensor>(2);
4545
const Tensor* position_ids = context->Input<Tensor>(3); // Optional, can be nullptr
4646

47+
// If rotary_embedding_dim is set (>0) and num_heads attribute not provided (==0),
48+
// we can only proceed if input is 4D (B, num_heads, S, head_size) so num_heads can be inferred.
49+
if (rotary_embedding_dim > 0 && num_heads <= 0) {
50+
const auto& dims = input->Shape().GetDims();
51+
ORT_ENFORCE(dims.size() == 4,
52+
"Attribute 'num_heads' must be provided when 'rotary_embedding_dim' is specified "
53+
"and input is not rank-4 (batch, num_heads, sequence, head).");
54+
}
55+
4756
RotaryParameters parameters = {};
4857
ORT_RETURN_IF_ERROR(rotary_embedding_helper::CheckInputs<Tensor>(input,
4958
position_ids,

onnxruntime/test/onnx/TestCase.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1408,6 +1408,9 @@ std::unique_ptr<std::set<BrokenTest>> GetBrokenTests(const std::string& provider
14081408
broken_tests->insert({"gridsample_volumetric_nearest_align_corners_1", "unknown version"});
14091409
broken_tests->insert({"rotary_embedding_no_position_ids_expanded", "unknown version"});
14101410
broken_tests->insert({"rotary_embedding_no_position_ids_interleaved_expanded", "unknown version"});
1411+
broken_tests->insert({"rotary_embedding_no_position_ids_rotary_dim", "unknown version"});
1412+
broken_tests->insert({"rotary_embedding_with_interleaved_rotary_dim", "unknown version"});
1413+
broken_tests->insert({"rotary_embedding_with_rotary_dim", "unknown version"});
14111414
// Fails since QNN SDK 2.17.0:
14121415
// expected 7.70947 (40f6b3f3), got 7.84096 (40fae920), diff: 0.131491, tol=0.00870947 idx=419. 100 of 1715 differ
14131416
broken_tests->insert({"facedetection_op8_qdq", "result differs"});

onnxruntime/test/testdata/onnx_backend_test_series_filters.jsonc

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,12 +123,10 @@
123123
"^test_if_opt",
124124
"^test_loop16_seq_none",
125125
"^test_identity_opt",
126+
// rotary dim should be fixed in onnx==1.19.1
126127
"^test_rotary_embedding_no_position_ids_rotary_dim",
127128
"^test_rotary_embedding_with_interleaved_rotary_dim",
128129
"^test_rotary_embedding_with_rotary_dim",
129-
"^test_rotary_embedding_3d_input_expanded",
130-
"^test_rotary_embedding_interleaved_expanded",
131-
"^test_rotary_embedding_no_position_ids_interleaved_expanded",
132130
"^test_rotary_embedding_expanded", //webgpu
133131
"^test_rotary_embedding_no_position_ids_expanded", //webgpu
134132
// Following tests are for opset 16 ops and are not yet implemented in ORT

0 commit comments

Comments
 (0)