Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions backends/vulkan/runtime/graph/ops/glsl/embedding.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ void main() {
const ivec3 in_lpos = ivec3(out_tidx.y, out_tidx.z * 4 + i, out_tidx.w / 4);
const int in_texel_elem = load_texel_lpos(t_in, in_lpos, in_axis_map)[out_tidx.w % 4];

// Read weight tensor for embedding.
const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem, 0);
out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map).x;
// Read weight tensor for embedding, it is height-packed.
const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem / 4, 0);
out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map)[in_texel_elem % 4];
}

write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map);
Expand Down
17 changes: 15 additions & 2 deletions backends/vulkan/runtime/graph/ops/impl/Embedding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,21 @@

#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>

#include <executorch/backends/vulkan/runtime/utils/StorageUtils.h>

namespace vkcompute {

using utils::StorageType;
using utils::GPUMemoryLayout;

void check_embedding_args(
const api::vTensor& weight,
const api::vTensor& in,
const api::vTensor& out) {
VK_CHECK_COND(check_packed_dim_is(weight, WHCN::kChannelsDim));
// The packing logic may not be trivial here. Input and output are Channel
// Packed, which is default for the Vulkan backend. However, weight vector is
// height-packed instead of channel-packed for space reason.
VK_CHECK_COND(check_packed_dim_is(weight, WHCN::kHeightDim));
VK_CHECK_COND(check_packed_dim_is(in, WHCN::kChannelsDim));
VK_CHECK_COND(check_packed_dim_is(out, WHCN::kChannelsDim));
}
Expand Down Expand Up @@ -58,7 +66,12 @@ void add_embedding_node(
void embedding(ComputeGraph& graph, const std::vector<ValueRef>& args) {
ValueRef in = args[1];
ValueRef out = args[5];
ValueRef weight = prepack_standard_like(graph, args[0], out);

ValueRef weight = prepack_standard(
graph,
args[0],
StorageType::TEXTURE_2D,
GPUMemoryLayout::TENSOR_HEIGHT_PACKED);

add_embedding_node(graph, weight, in, out);
}
Expand Down
Loading