Skip to content

Commit 6547b38

Browse files
committed
Update on "[ET-VK] Introduce rotary embedding custom op"
## Context

As the title says, this introduces a custom op to calculate rotary positional embeddings in LLMs. The custom op achieves the same result as the `apply_rotary_emb` Python function. Please see the documentation comments in the shader for more details.

Differential Revision: [D64697588](https://our.internmc.facebook.com/intern/diff/D64697588/)

[ghstack-poisoned]
2 parents fda00c2 + f914d9b commit 6547b38

File tree

2 files changed

+17
-6
lines changed

2 files changed

+17
-6
lines changed

backends/vulkan/runtime/graph/ops/glsl/rotary_embedding.glsl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ layout(constant_id = 3) const int packed_dim = 0;
4646
*
4747
* The computation of rotary positional embeddings can be summarized with the
4848
* following equations:
49-
49+
*
5050
* xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i]
5151
* xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i]
5252
*
@@ -55,9 +55,9 @@ layout(constant_id = 3) const int packed_dim = 0;
5555
* The even components of the output multiply the even components of the inputs
5656
* with the freqs_cos tensor, and the odd components of the inputs with the
5757
* freqs_sin tensor. The odd components of the output swap this. Throughout the
58-
* implements the even components have the _r suffix and the odd components have
59-
* the _i suffix; this is likely a reference to complex numbers which can be
60-
* used to represent rotations.
58+
* implementation the even components have the _r suffix and the odd components
59+
* have the _i suffix; this is a reference to complex numbers which can be used
60+
* to represent rotations.
6161
*
6262
* Note that this implementation assumes that all input tensors have the width
6363
* dim as the packed dim.
@@ -97,6 +97,9 @@ void main() {
9797
write_texel(xqout, x_pos_1, xout_tex_1);
9898
write_texel(xqout, x_pos_2, xout_tex_2);
9999

100+
// n_heads will be greater than or equal to n_kv_heads, therefore xq and xqout
101+
// may have a larger height dim than xk and xkout. Only compute xkout if this
102+
// invocation is still within bounds.
100103
if (any(greaterThanEqual(x_pos_2, xkout_limits))) {
101104
return;
102105
}

backends/vulkan/test/op_tests/rotary_embedding_test.cpp

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,8 +156,8 @@ void test_reference(
156156
graph.copy_from_staging(
157157
staging_xk_out, vk_xk_out.mutable_data_ptr(), vk_xk_out.numel());
158158

159-
EXPECT_TRUE(at::allclose(xq_out, vk_xq_out));
160-
EXPECT_TRUE(at::allclose(xk_out, vk_xk_out));
159+
EXPECT_TRUE(at::allclose(xq_out, vk_xq_out, 1e-4, 1e-4));
160+
EXPECT_TRUE(at::allclose(xk_out, vk_xk_out, 1e-4, 1e-4));
161161
}
162162

163163
TEST(VulkanRotaryEmbeddingTest, rotary_embedding_test) {
@@ -170,3 +170,11 @@ TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test) {
170170
/*n_kv_heads=*/8,
171171
/*dim=*/2048);
172172
}
173+
174+
TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test_seq_len_3) {
175+
test_reference(
176+
/*n_heads=*/32,
177+
/*n_kv_heads=*/8,
178+
/*dim=*/2048,
179+
/*seq_len=*/3);
180+
}

0 commit comments

Comments
 (0)