Skip to content

Commit ecb6852

Browse files
author
ssjia
committed
[ET-VK] Contiguous buffer implementation for embedding
Pull Request resolved: #15160 Title says it all! Also moved around some structs throughout the `glslh` files for better code organization. ghstack-source-id: 317368559 @exported-using-ghexport Differential Revision: [D84716456](https://our.internmc.facebook.com/intern/diff/D84716456/)
1 parent 0ecea15 commit ecb6852

16 files changed

+360
-51
lines changed

backends/vulkan/runtime/graph/ops/glsl/common.glslh

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,10 +29,6 @@
2929
#define mod_4(x) ((x) & 3)
3030
#define mod_8(x) ((x) & 7)
3131

32-
struct TensorIndex4D {
33-
ivec4 data;
34-
};
35-
3632
int sign_extend_8bit(const int val) {
3733
if ((val & 0x80) != 0) {
3834
return val | (~0xFF);
@@ -86,19 +82,4 @@ int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
8682
return pack_into_int32(quantized);
8783
}
8884

89-
#ifdef DEBUG_MODE
90-
91-
#extension GL_EXT_debug_printf : require
92-
93-
void printTensorIndex4D(const TensorIndex4D index) {
94-
debugPrintfEXT(
95-
"tensor_idx: %d, %d, %d, %d\\n",
96-
index.data.x,
97-
index.data.y,
98-
index.data.z,
99-
index.data.w);
100-
}
101-
102-
#endif // DEBUG_MODE
103-
10485
#endif // COMMON_GLSLH

backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block.glslh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
#extension GL_EXT_control_flow_attributes : require
2525

26-
#include "common.glslh"
26+
#include "indexing.glslh"
2727
#include "conv2d_common.glslh"
2828

2929
struct Im2ColMatrixIdx {

backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_load.glslh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
#extension GL_EXT_debug_printf : require
2525

26-
#include "common.glslh"
26+
#include "indexing.glslh"
2727
#include "conv2d_common.glslh"
2828
#include "conv2d_fp_im2col_block.glslh"
2929
#include "linear_fp_input_tile.glslh"

backends/vulkan/runtime/graph/ops/glsl/conv2d_fp_im2col_block_store.glslh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020

2121
#extension GL_EXT_control_flow_attributes : require
2222

23-
#include "common.glslh"
23+
#include "indexing.glslh"
2424
#include "conv2d_common.glslh"
2525
#include "conv2d_fp_im2col_block.glslh"
2626
#include "linear_fp_output_tile.glslh"
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define T ${buffer_scalar_type(DTYPE)}
14+
15+
${define_active_storage_type("buffer")}
16+
${define_required_extensions(DTYPE)}
17+
18+
#extension GL_EXT_control_flow_attributes : require
19+
20+
layout(std430) buffer;
21+
22+
#define DEBUG_MODE
23+
#include "indexing.glslh"
24+
25+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
26+
${layout_declare_tensor(B, "r", "t_indices", "int", "buffer")}
27+
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")}
28+
29+
${layout_declare_ubo(B, "BufferMetadata", "outp")}
30+
${layout_declare_ubo(B, "BufferMetadata", "indices")}
31+
${layout_declare_ubo(B, "BufferMetadata", "weight")}
32+
33+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
34+
35+
TensorIndex out_tidx_to_indices_tidx(const TensorIndex out_tidx) {
36+
TensorIndex indices_tidx;
37+
int d = 0;
38+
// First half of the index
39+
[[unroll]] for (uint d = 0; d < ndim(indices); ++d) {
40+
indices_tidx.data[div_4(d)][mod_4(d)] = idx_at(out_tidx, d + 1);
41+
}
42+
[[unroll]] for (uint d = ndim(indices); d < DIMLIMIT; ++d) {
43+
indices_tidx.data[div_4(d)][mod_4(d)] = 0;
44+
}
45+
return indices_tidx;
46+
}
47+
48+
int load_embedding_idx(const TensorIndex indices_tidx) {
49+
const uint bufi = tensor_idx_to_linear_idx(indices, indices_tidx);
50+
return t_indices[bufi];
51+
}
52+
53+
T load_weight_elem(const int embedding_idx, const uint dim_idx) {
54+
uint bufi = uint(embedding_idx) * width(weight) + dim_idx;
55+
return t_weight[bufi];
56+
}
57+
58+
void main() {
59+
const uint out_bufi = gl_GlobalInvocationID.x;
60+
if (out_of_bounds(out_bufi, outp)) {
61+
return;
62+
}
63+
64+
TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi);
65+
TensorIndex indices_tidx = out_tidx_to_indices_tidx(out_tidx);
66+
67+
const uint bufi = tensor_idx_to_linear_idx(indices, indices_tidx);
68+
const int embedding_idx = load_embedding_idx(indices_tidx);
69+
70+
t_out[out_bufi] = load_weight_elem(embedding_idx, x(out_tidx));
71+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
embedding_buffer:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
STORAGE: buffer
11+
generate_variant_forall:
12+
DTYPE:
13+
- VALUE: half
14+
- VALUE: float
15+
shader_variants:
16+
- NAME: embedding_buffer

backends/vulkan/runtime/graph/ops/glsl/embedding.glsl renamed to backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.glsl

File renamed without changes.

backends/vulkan/runtime/graph/ops/glsl/embedding.yaml renamed to backends/vulkan/runtime/graph/ops/glsl/embedding_legacy.yaml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
embedding:
1+
embedding_legacy:
22
parameter_names_with_default_values:
33
DTYPE: float
44
NDIM: 3
@@ -9,4 +9,4 @@ embedding:
99
- VALUE: float
1010
- VALUE: int32
1111
shader_variants:
12-
- NAME: embedding
12+
- NAME: embedding_legacy
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#version 450 core
10+
11+
#define PRECISION ${PRECISION}
12+
13+
#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
14+
#define T ${texel_load_component_type(DTYPE, "texture3d")}
15+
16+
${define_active_storage_type("texture3d")}
17+
${define_required_extensions(DTYPE)}
18+
19+
#extension GL_EXT_control_flow_attributes : require
20+
21+
layout(std430) buffer;
22+
23+
#define DEBUG_MODE
24+
#include "common.glslh"
25+
#include "indexing.glslh"
26+
27+
${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
28+
${layout_declare_tensor(B, "r", "t_indices", "int", "texture3d")}
29+
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")}
30+
31+
${layout_declare_ubo(B, "TextureMetadata", "outp")}
32+
${layout_declare_ubo(B, "TextureMetadata", "indices")}
33+
${layout_declare_ubo(B, "BufferMetadata", "weight")}
34+
35+
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
36+
37+
int load_embedding_idx(const TensorIndex4D out_tidx) {
38+
TensorIndex4D indices_tidx;
39+
indices_tidx.data.xyz = out_tidx.data.yzw;
40+
indices_tidx.data.w = 0;
41+
42+
TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
43+
indices_tidx, indices);
44+
45+
const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
46+
return in_texel[elem_pos.comp];
47+
}
48+
49+
VEC4_T load_weight_texel(const int embedding_idx, const int dim_idx) {
50+
int buf_i = embedding_idx * int(width(weight)) + dim_idx;
51+
VEC4_T weight_texel;
52+
[[unroll]] for (int i = 0; i < 4; ++i) {
53+
weight_texel[i] = T(t_weight[buf_i++]);
54+
}
55+
return weight_texel;
56+
}
57+
58+
void main() {
59+
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);
60+
61+
if (out_of_bounds(out_pos, outp)) {
62+
return;
63+
}
64+
65+
TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
66+
const int embedding_idx = load_embedding_idx(out_tidx);
67+
68+
const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);
69+
70+
imageStore(t_out, out_pos, weight_texel);
71+
}
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD-style license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
embedding_texture:
8+
parameter_names_with_default_values:
9+
DTYPE: float
10+
generate_variant_forall:
11+
DTYPE:
12+
- VALUE: half
13+
- VALUE: float
14+
shader_variants:
15+
- NAME: embedding_texture3d

0 commit comments

Comments
 (0)