Skip to content

Commit b5b16e8

Browse files
committed
[ET-VK] Introduce rotary embedding custom op
## Context As title; introduces a custom op to calculate rotary positional embeddings in LLMs. The custom op achieves the same result as the `apply_rotary_emb` Python function. Please see the documentation comments in the shader for more details. Differential Revision: [D64697588](https://our.internmc.facebook.com/intern/diff/D64697588/) ghstack-source-id: 249135613 Pull Request resolved: #6392
1 parent 2d6de75 commit b5b16e8

File tree

5 files changed

+428
-0
lines changed

5 files changed

+428
-0
lines changed
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14+
15+
${define_required_extensions(DTYPE)}
16+
17+
layout(std430) buffer;
18+
19+
${layout_declare_tensor(B, "w", "xqout", DTYPE, STORAGE)}
20+
${layout_declare_tensor(B, "w", "xkout", DTYPE, STORAGE)}
21+
${layout_declare_tensor(B, "r", "xq", DTYPE, STORAGE)}
22+
${layout_declare_tensor(B, "r", "xk", DTYPE, STORAGE)}
23+
${layout_declare_tensor(B, "r", "freqs_cos", DTYPE, STORAGE)}
24+
${layout_declare_tensor(B, "r", "freqs_sin", DTYPE, STORAGE)}
25+
${layout_declare_ubo(B, "ivec3", "xqout_limits")}
26+
${layout_declare_ubo(B, "ivec3", "xkout_limits")}
27+
28+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
29+
30+
layout(constant_id = 3) const int packed_dim = 0;
31+
32+
#include "indexing_utils.h"
33+
34+
/*
35+
* This shader computes rotary positional embeddings which are used in the Llama
36+
* model architecture. There are 4 input tensors with the following shapes.
37+
* Note that head_dim = embedding_dim / num_heads
38+
*
39+
* 1. xq (batch_size, sequence_len, num_heads, head_dim)
40+
* 2. xk (batch_size, sequence_len, num_kv_heads, head_dim)
41+
* 3. freqs_cos (sequence_len, head_dim / 2)
42+
* 4. freqs_sin (sequence_len, head_dim / 2)
43+
*
44+
* Two output tensors are produced, with the same shapes as xq and xk
45+
* respectively.
46+
*
47+
* The computation of rotary positional embeddings can be summarized with the
48+
* following equations:
49+
50+
* xq_out[2i] = xq[2i] * freqs_cos[i] - xq[2i + 1] * freqs_sin[i]
51+
* xq_out[2i + 1] = xq[2i] * freqs_sin[i] + xq[2i + 1] * freqs_cos[i]
52+
*
53+
* Essentially, taking each row along head_dim of the xq and xk tensors, each
54+
* row is split into even and odd elements (xq[2i] and xq[2i + 1] respectively).
55+
* The even components of the output multiply the even components of the inputs
56+
* with the freqs_cos tensor, and the odd components of the inputs with the
57+
* freqs_sin tensor. The odd components of the output swap this. Throughout the
58+
* implementation, the even components have the _r suffix and the odd components have
59+
* the _i suffix; this is likely a reference to complex numbers which can be
60+
* used to represent rotations.
61+
*
62+
* Note that this implementation assumes that all input tensors have the width
63+
* dim as the packed dim.
64+
*/
65+
/*
 * Rotates the 8 consecutive elements held by two width-adjacent texels.
 *
 * in_tex_1 / in_tex_2 hold interleaved (even, odd) element pairs. cos_tex and
 * sin_tex hold the 4 cos/sin factors that correspond to those 4 pairs. The
 * rotated pairs are re-interleaved and returned through out_tex_1 / out_tex_2
 * in the same element order as the inputs.
 */
void rotate_texel_pair(
    const VEC4_T in_tex_1,
    const VEC4_T in_tex_2,
    const VEC4_T cos_tex,
    const VEC4_T sin_tex,
    out VEC4_T out_tex_1,
    out VEC4_T out_tex_2) {
  // De-interleave: gather the even (_r) and odd (_i) elements of both texels.
  const VEC4_T in_r = VEC4_T(in_tex_1.xz, in_tex_2.xz);
  const VEC4_T in_i = VEC4_T(in_tex_1.yw, in_tex_2.yw);

  const VEC4_T rot_r = in_r * cos_tex - in_i * sin_tex;
  const VEC4_T rot_i = in_r * sin_tex + in_i * cos_tex;

  // Re-interleave so that each output texel matches the input element order.
  out_tex_1 = VEC4_T(rot_r.x, rot_i.x, rot_r.y, rot_i.y);
  out_tex_2 = VEC4_T(rot_r.z, rot_i.z, rot_r.w, rot_i.w);
}

void main() {
  // Every invocation produces two output texels that are adjacent along x.
  // A single texel fetched from freqs_cos/freqs_sin carries the factors needed
  // for both of them, so the loaded frequency data is re-used.
  const ivec3 pos_1 =
      ivec3(gl_GlobalInvocationID.x * 2, gl_GlobalInvocationID.yz);
  const ivec3 pos_2 = ivec3(pos_1.x + 1, pos_1.yz);

  // Nothing to do if the second texel of the pair lies outside xqout.
  if (any(greaterThanEqual(pos_2, xqout_limits))) {
    return;
  }

  // The freqs tensors are addressed with this invocation's x and z ids.
  const ivec3 freqs_pos = ivec3(gl_GlobalInvocationID.xz, 0);
  const VEC4_T cos_tex = load_texel(freqs_cos, freqs_pos);
  const VEC4_T sin_tex = load_texel(freqs_sin, freqs_pos);

  VEC4_T out_tex_1;
  VEC4_T out_tex_2;

  // xqout
  rotate_texel_pair(
      load_texel(xq, pos_1),
      load_texel(xq, pos_2),
      cos_tex,
      sin_tex,
      out_tex_1,
      out_tex_2);
  write_texel(xqout, pos_1, out_tex_1);
  write_texel(xqout, pos_2, out_tex_2);

  // xkout has its own limits, so this position may fall outside it even
  // though it was inside xqout.
  if (any(greaterThanEqual(pos_2, xkout_limits))) {
    return;
  }

  // xkout
  rotate_texel_pair(
      load_texel(xk, pos_1),
      load_texel(xk, pos_2),
      cos_tex,
      sin_tex,
      out_tex_1,
      out_tex_2);
  write_texel(xkout, pos_1, out_tex_1);
  write_texel(xkout, pos_2, out_tex_2);
}
Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
rotary_embedding:
2+
parameter_names_with_default_values:
3+
DTYPE: float
4+
STORAGE: texture3d
5+
generate_variant_forall:
6+
DTYPE:
7+
- VALUE: half
8+
- VALUE: float
9+
shader_variants:
10+
- NAME: rotary_embedding
Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
10+
11+
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
12+
13+
namespace vkcompute {
14+
15+
/*
 * Resize hook for the rotary embedding dispatch node.
 *
 * The node is created with args[0] = {xq_out, xk_out} (written) and
 * args[1] = {xq, xk, freqs_cos, freqs_sin} (read). Each output has the same
 * shape as its corresponding input, so whenever the inputs are resized the
 * outputs are resized to match.
 *
 * @param graph      graph that owns the tensors referenced by args
 * @param args       argument groups the dispatch node was created with
 * @param extra_args unused
 */
void resize_rotary_embedding_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;

  vTensorPtr xq_out = graph->get_tensor(args[0].refs[0]);
  vTensorPtr xk_out = graph->get_tensor(args[0].refs[1]);

  vTensorPtr xq = graph->get_tensor(args[1].refs[0]);
  vTensorPtr xk = graph->get_tensor(args[1].refs[1]);

  // xq and xk may differ in their number of heads, so each output follows its
  // own input rather than sharing one set of sizes.
  xq_out->virtual_resize(xq->sizes());
  xk_out->virtual_resize(xk->sizes());
}
27+
28+
/*
 * Validates the operands and appends the dispatch node that runs the
 * rotary_embedding shader to the graph.
 *
 * @param graph     graph the node is added to
 * @param xq        query input
 * @param xk        key input
 * @param freqs_cos cos factors; its last dim must be half of xq's last dim
 * @param freqs_sin sin factors; must have the same sizes as freqs_cos
 * @param xq_out    output written for xq
 * @param xk_out    output written for xk
 *
 * Raises via VK_CHECK_COND if any of the shape or layout requirements below
 * are not met.
 */
void add_rotary_embedding_node(
    ComputeGraph& graph,
    const ValueRef xq,
    const ValueRef xk,
    const ValueRef freqs_cos,
    const ValueRef freqs_sin,
    const ValueRef xq_out,
    const ValueRef xk_out) {
  // Shape agreement between the inputs.
  VK_CHECK_COND(graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, xk));
  VK_CHECK_COND(graph.size_at<int>(-3, xq) == graph.size_at<int>(-3, xk));
  VK_CHECK_COND(
      graph.size_at<int>(-1, xq) == graph.size_at<int>(-1, freqs_cos) * 2);
  VK_CHECK_COND(graph.sizes_of(freqs_cos) == graph.sizes_of(freqs_sin));

  // The shader is written for width-packed inputs with a standard axis map.
  VK_CHECK_COND(graph.packed_dim_of(xq) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(xk) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(freqs_cos) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.packed_dim_of(freqs_sin) == WHCN::kWidthDim);
  VK_CHECK_COND(graph.has_standard_axis_map(xq));
  VK_CHECK_COND(graph.has_standard_axis_map(xk));
  VK_CHECK_COND(graph.has_standard_axis_map(freqs_cos));
  VK_CHECK_COND(graph.has_standard_axis_map(freqs_sin));

  std::string kernel_name = "rotary_embedding";
  add_dtype_suffix(kernel_name, graph.dtype_of(xq_out));

  utils::uvec3 global_wg_size = graph.logical_limits_of(xq_out);
  // Each shader invocation writes two texels that are adjacent along x and
  // skips the pair entirely if the second one is out of bounds. With an odd x
  // extent the halving below would truncate and the trailing texel would
  // silently never be written, so reject that case explicitly.
  VK_CHECK_COND(global_wg_size[0] % 2 == 0);
  global_wg_size[0] /= 2;
  const utils::uvec3 local_wg_size = graph.create_local_wg_size(global_wg_size);

  graph.execute_nodes().emplace_back(new DispatchNode(
      graph,
      // Shader
      VK_KERNEL_FROM_STR(kernel_name),
      // Workgroup sizes
      global_wg_size,
      local_wg_size,
      // Inputs and Outputs
      {{{xq_out, xk_out}, vkapi::kWrite},
       {{xq, xk, freqs_cos, freqs_sin}, vkapi::kRead}},
      // Parameter buffers
      {graph.logical_limits_ubo(xq_out), graph.logical_limits_ubo(xk_out)},
      // Specialization Constants
      {},
      // Resizing Logic
      resize_rotary_embedding_node));
}
75+
76+
/*
 * Operator entry point. Unpacks the flat argument list and forwards it to
 * add_rotary_embedding_node.
 *
 * Expected argument order: xq, xk, freqs_cos, freqs_sin, followed by a value
 * list holding the two outputs (xq_out, xk_out).
 */
void apply_rotary_emb(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueListPtr outputs = graph.get_value_list(args[4]);
  const ValueRef xq_out = outputs->at(0);
  const ValueRef xk_out = outputs->at(1);

  const ValueRef xq = args[0];
  const ValueRef xk = args[1];
  const ValueRef freqs_cos = args[2];
  const ValueRef freqs_sin = args[3];

  add_rotary_embedding_node(
      graph, xq, xk, freqs_cos, freqs_sin, xq_out, xk_out);
}
84+
85+
// Makes apply_rotary_emb available under the et_vk.apply_rotary_emb.default
// operator name.
REGISTER_OPERATORS {
  VK_REGISTER_OP(et_vk.apply_rotary_emb.default, apply_rotary_emb);
}
88+
89+
} // namespace vkcompute
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <gtest/gtest.h>
10+
11+
#include <ATen/ATen.h>
12+
13+
#include <executorch/backends/vulkan/runtime/api/api.h>
14+
#include <executorch/backends/vulkan/runtime/graph/ComputeGraph.h>
15+
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
16+
17+
#include <cassert>
18+
19+
//
20+
// Reference Implementations
21+
//
22+
23+
std::pair<at::Tensor, at::Tensor> rotary_embedding_impl(
24+
const at::Tensor& xq,
25+
const at::Tensor& xk,
26+
const at::Tensor& freqs_cos,
27+
const at::Tensor& freqs_sin) {
28+
std::vector<at::Tensor> xq_even_odd = at::unbind(
29+
xq.reshape({xq.size(0), xq.size(1), xq.size(2), xq.size(3) / 2, 2}), -1);
30+
at::Tensor& xq_r = xq_even_odd[0];
31+
at::Tensor& xq_i = xq_even_odd[1];
32+
33+
std::vector<at::Tensor> xk_even_odd = at::unbind(
34+
xk.reshape({xk.size(0), xk.size(1), xk.size(2), xk.size(3) / 2, 2}), -1);
35+
at::Tensor& xk_r = xk_even_odd[0];
36+
at::Tensor& xk_i = xk_even_odd[1];
37+
38+
at::Tensor freqs_cos_reshape =
39+
freqs_cos.reshape({1, freqs_cos.size(0), 1, freqs_cos.size(1)});
40+
at::Tensor freqs_sin_reshape =
41+
freqs_sin.reshape({1, freqs_sin.size(0), 1, freqs_sin.size(1)});
42+
43+
at::Tensor xq_out_r = xq_r * freqs_cos_reshape - xq_i * freqs_sin_reshape;
44+
at::Tensor xq_out_i = xq_r * freqs_sin_reshape + xq_i * freqs_cos_reshape;
45+
at::Tensor xk_out_r = xk_r * freqs_cos_reshape - xk_i * freqs_sin_reshape;
46+
at::Tensor xk_out_i = xk_r * freqs_sin_reshape + xk_i * freqs_cos_reshape;
47+
48+
at::Tensor xq_out = at::flatten(at::stack({xq_out_r, xq_out_i}, -1), 3);
49+
at::Tensor xk_out = at::flatten(at::stack({xk_out_r, xk_out_i}, -1), 3);
50+
51+
return std::make_pair(xq_out, xk_out);
52+
}
53+
54+
//
55+
// Test functions
56+
//
57+
58+
/*
 * Maps an ATen scalar type to the Vulkan API scalar type used when creating
 * graph tensors. Note that c10::kLong maps to vkapi::kInt. Any type not
 * listed below is rejected with VK_THROW.
 */
vkcompute::vkapi::ScalarType from_at_scalartype(c10::ScalarType at_scalartype) {
  using namespace vkcompute;

  if (at_scalartype == c10::kFloat) {
    return vkapi::kFloat;
  }
  if (at_scalartype == c10::kHalf) {
    return vkapi::kHalf;
  }
  if (at_scalartype == c10::kInt || at_scalartype == c10::kLong) {
    return vkapi::kInt;
  }
  if (at_scalartype == c10::kChar) {
    return vkapi::kChar;
  }
  if (at_scalartype == c10::kByte) {
    return vkapi::kByte;
  }

  VK_THROW("Unsupported at::ScalarType!");
}
77+
78+
void test_reference(
79+
const int n_heads = 4,
80+
const int n_kv_heads = 2,
81+
const int dim = 32,
82+
const int seq_len = 1) {
83+
const int head_dim = dim / n_heads;
84+
85+
at::Tensor xq = at::rand(
86+
{1, seq_len, n_heads, head_dim}, at::device(at::kCPU).dtype(at::kFloat));
87+
at::Tensor xk = at::rand(
88+
{1, seq_len, n_kv_heads, head_dim},
89+
at::device(at::kCPU).dtype(at::kFloat));
90+
at::Tensor freqs_cos =
91+
at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat));
92+
at::Tensor freqs_sin =
93+
at::rand({seq_len, head_dim / 2}, at::device(at::kCPU).dtype(at::kFloat));
94+
95+
std::pair<at::Tensor, at::Tensor> outs =
96+
rotary_embedding_impl(xq, xk, freqs_cos, freqs_sin);
97+
at::Tensor& xq_out = outs.first;
98+
at::Tensor& xk_out = outs.second;
99+
100+
// Build Vulkan graph
101+
using namespace vkcompute;
102+
103+
GraphConfig config;
104+
config.set_storage_type_override(utils::kTexture3D);
105+
ComputeGraph graph(config);
106+
107+
#define MAKE_INPUT_FOR(x) \
108+
IOValueRef r_##x = graph.add_input_tensor( \
109+
x.sizes().vec(), from_at_scalartype(x.scalar_type()));
110+
111+
MAKE_INPUT_FOR(xq);
112+
MAKE_INPUT_FOR(xk);
113+
MAKE_INPUT_FOR(freqs_cos);
114+
MAKE_INPUT_FOR(freqs_sin);
115+
116+
const ValueRef r_xq_out = graph.add_tensor(
117+
xq_out.sizes().vec(), from_at_scalartype(xq_out.scalar_type()));
118+
const ValueRef r_xk_out = graph.add_tensor(
119+
xk_out.sizes().vec(), from_at_scalartype(xk_out.scalar_type()));
120+
121+
VK_GET_OP_FN("et_vk.apply_rotary_emb.default")
122+
(graph,
123+
{r_xq.value,
124+
r_xk.value,
125+
r_freqs_cos.value,
126+
r_freqs_sin.value,
127+
graph.add_value_list({r_xq_out, r_xk_out})});
128+
129+
ValueRef staging_xq_out = graph.set_output_tensor(r_xq_out);
130+
ValueRef staging_xk_out = graph.set_output_tensor(r_xk_out);
131+
132+
graph.prepare();
133+
graph.encode_prepack();
134+
graph.prepack();
135+
graph.encode_execute();
136+
137+
//
138+
// Run model
139+
//
140+
141+
graph.propagate_resize();
142+
graph.copy_into_staging(r_xq.staging, xq.const_data_ptr(), xq.numel());
143+
graph.copy_into_staging(r_xk.staging, xk.const_data_ptr(), xk.numel());
144+
graph.copy_into_staging(
145+
r_freqs_cos.staging, freqs_cos.const_data_ptr(), freqs_cos.numel());
146+
graph.copy_into_staging(
147+
r_freqs_sin.staging, freqs_sin.const_data_ptr(), freqs_sin.numel());
148+
149+
graph.execute();
150+
151+
at::Tensor vk_xq_out = at::empty_like(xq_out);
152+
graph.copy_from_staging(
153+
staging_xq_out, vk_xq_out.mutable_data_ptr(), vk_xq_out.numel());
154+
155+
at::Tensor vk_xk_out = at::empty_like(xk_out);
156+
graph.copy_from_staging(
157+
staging_xk_out, vk_xk_out.mutable_data_ptr(), vk_xk_out.numel());
158+
159+
EXPECT_TRUE(at::allclose(xq_out, vk_xq_out));
160+
EXPECT_TRUE(at::allclose(xk_out, vk_xk_out));
161+
}
162+
163+
// Exercises the op with test_reference's default parameters.
TEST(VulkanRotaryEmbeddingTest, rotary_embedding_test) {
  test_reference();
}
166+
167+
// Exercises the op with larger head counts and embedding dim; seq_len keeps
// its default value.
TEST(VulkanRotaryEmbeddingTest, rotary_embedding_llama3_params_test) {
  constexpr int kNumHeads = 32;
  constexpr int kNumKvHeads = 8;
  constexpr int kEmbeddingDim = 2048;

  test_reference(kNumHeads, kNumKvHeads, kEmbeddingDim);
}

0 commit comments

Comments
 (0)