
Commit 633057c

[ET-VK] Change weight packing in embedding
Pull Request resolved: #7063

The existing weight tensor for aten.embedding is created using a `tensor_like` of the output tensor, which defaults to channel packing. However, the weight is actually a 2D tensor of shape `(num_embedding, dim_of_embedding)`, so either width or height packing is more space-efficient. This diff changes the implementation to use height packing.

ghstack-source-id: 255439082
Differential Revision: [D66421366](https://our.internmc.facebook.com/intern/diff/D66421366/)

Co-authored-by: Justin Yip <[email protected]>
1 parent 2a292c3 commit 633057c
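
For context on the space argument in the commit message: texels hold four elements along the packed dimension, and the channel extent of a 2D weight is 1, so channel packing leaves three of every four texel lanes unused. Below is a rough sketch of the texel-count arithmetic, assuming the usual 4-element texels; the sizes are made up and this is not the backend's actual allocation code.

#include <cstdio>

// Rough illustration only: a 2D weight of shape (num_embeddings, dim) has
// width = dim, height = num_embeddings, channels = 1 in WHCN terms. A texel
// stores 4 elements along the packed dimension, so packing along the size-1
// channel dimension wastes 3 of every 4 lanes.
int main() {
  const long num_embeddings = 1000;  // hypothetical sizes
  const long dim = 768;

  // Channel-packed texel grid: (dim, num_embeddings, ceil(1 / 4)) texels.
  const long channel_packed_texels = dim * num_embeddings * 1;

  // Height-packed texel grid: (dim, ceil(num_embeddings / 4), 1) texels.
  const long height_packed_texels = dim * ((num_embeddings + 3) / 4);

  std::printf("channel-packed: %ld texels\n", channel_packed_texels);  // 768000
  std::printf("height-packed:  %ld texels\n", height_packed_texels);   // 192000
  return 0;
}

Width packing would give the same savings for a 2D weight; this commit picks height packing.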

2 files changed (+18, -5 lines)


backends/vulkan/runtime/graph/ops/glsl/embedding.glsl

Lines changed: 3 additions & 3 deletions
@@ -47,9 +47,9 @@ void main() {
     const ivec3 in_lpos = ivec3(out_tidx.y, out_tidx.z * 4 + i, out_tidx.w / 4);
     const int in_texel_elem = load_texel_lpos(t_in, in_lpos, in_axis_map)[out_tidx.w % 4];

-    // Read weight tensor for embedding.
-    const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem, 0);
-    out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map).x;
+    // Read weight tensor for embedding; it is height-packed.
+    const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem / 4, 0);
+    out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map)[in_texel_elem % 4];
  }

  write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map);
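
The shader change boils down to new index arithmetic: with the weight height-packed, embedding row r (the index loaded from the input tensor) lives in texel row r / 4, component r % 4, while the output column still selects the texel's x position. A minimal CPU-side analogue of that fetch, with hypothetical names that are not part of the shader:

#include <array>
#include <vector>

// Hypothetical CPU analogue of the shader's height-packed weight fetch.
// texels[y * width + x] holds four consecutive embedding rows of column x.
using Texel = std::array<float, 4>;

float load_height_packed(
    const std::vector<Texel>& texels,
    int width,  // embedding dim, i.e. the texel grid's x extent
    int col,    // plays the role of out_tidx.x in the shader
    int row) {  // plays the role of in_texel_elem (the looked-up index)
  const int texel_y = row / 4;  // which texel row holds this embedding row
  const int lane = row % 4;     // which component of the vec4 to read
  return texels[texel_y * width + col][lane];
}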

backends/vulkan/runtime/graph/ops/impl/Embedding.cpp

Lines changed: 15 additions & 2 deletions
@@ -15,13 +15,21 @@

 #include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

+#include <executorch/backends/vulkan/runtime/utils/StorageUtils.h>
+
 namespace vkcompute {

+using utils::GPUMemoryLayout;
+using utils::StorageType;
+
 void check_embedding_args(
     const api::vTensor& weight,
     const api::vTensor& in,
     const api::vTensor& out) {
-  VK_CHECK_COND(check_packed_dim_is(weight, WHCN::kChannelsDim));
+  // The packing logic may not be trivial here. Input and output are
+  // channel-packed, which is the default for the Vulkan backend. However, the
+  // weight tensor is height-packed instead of channel-packed for space reasons.
+  VK_CHECK_COND(check_packed_dim_is(weight, WHCN::kHeightDim));
   VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
   VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
 }
@@ -58,7 +66,12 @@ void add_embedding_node(
 void embedding(ComputeGraph& graph, const std::vector<ValueRef>& args) {
   ValueRef in = args[1];
   ValueRef out = args[5];
-  ValueRef weight = prepack_standard_like(graph, args[0], out);
+
+  ValueRef weight = prepack_standard(
+      graph,
+      args[0],
+      StorageType::TEXTURE_2D,
+      GPUMemoryLayout::TENSOR_HEIGHT_PACKED);

   add_embedding_node(graph, weight, in, out);
 }
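
Putting the two files together: the C++ side now prepacks the weight height-packed, and the shader indexes it with the divide/modulo arithmetic shown earlier. Below is a rough sketch of what that layout looks like when packing a row-major (num_embeddings, dim) weight into vec4 texels; it is illustrative only and is not the backend's prepack implementation.

#include <array>
#include <vector>

using Texel = std::array<float, 4>;

// Illustrative only: lay a row-major (num_embeddings, dim) weight into a
// height-packed texel grid of extent (dim, ceil(num_embeddings / 4)).
// Element (row, col) lands in texel (col, row / 4), lane row % 4, which is
// exactly where the updated embedding.glsl looks for it.
std::vector<Texel> pack_height(
    const std::vector<float>& weight, int num_embeddings, int dim) {
  const int texel_rows = (num_embeddings + 3) / 4;
  std::vector<Texel> texels(texel_rows * dim, Texel{0.f, 0.f, 0.f, 0.f});
  for (int row = 0; row < num_embeddings; ++row) {
    for (int col = 0; col < dim; ++col) {
      texels[(row / 4) * dim + col][row % 4] = weight[row * dim + col];
    }
  }
  return texels;
}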
