Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions backends/vulkan/runtime/graph/ops/glsl/common.glslh
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,6 @@
#define mod_4(x) ((x) & 3)
#define mod_8(x) ((x) & 7)

struct TensorIndex4D {
ivec4 data;
};

int sign_extend_8bit(const int val) {
if ((val & 0x80) != 0) {
return val | (~0xFF);
Expand Down Expand Up @@ -86,19 +82,4 @@ int quantize_and_pack(const vec4 vals, const float inv_scale, const int zp) {
return pack_into_int32(quantized);
}

#ifdef DEBUG_MODE

#extension GL_EXT_debug_printf : require

void printTensorIndex4D(const TensorIndex4D index) {
debugPrintfEXT(
"tensor_idx: %d, %d, %d, %d\\n",
index.data.x,
index.data.y,
index.data.z,
index.data.w);
}

#endif // DEBUG_MODE

#endif // COMMON_GLSLH
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

#extension GL_EXT_control_flow_attributes : require

#include "common.glslh"
#include "indexing.glslh"
#include "conv2d_common.glslh"

struct Im2ColMatrixIdx {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

#extension GL_EXT_debug_printf : require

#include "common.glslh"
#include "indexing.glslh"
#include "conv2d_common.glslh"
#include "conv2d_fp_im2col_block.glslh"
#include "linear_fp_input_tile.glslh"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

#extension GL_EXT_control_flow_attributes : require

#include "common.glslh"
#include "indexing.glslh"
#include "conv2d_common.glslh"
#include "conv2d_fp_im2col_block.glslh"
#include "linear_fp_output_tile.glslh"
Expand Down
71 changes: 71 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define T ${buffer_scalar_type(DTYPE)}

${define_active_storage_type("buffer")}
${define_required_extensions(DTYPE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

#define DEBUG_MODE
#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "buffer")}
${layout_declare_tensor(B, "r", "t_indices", "int", "buffer")}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")}

${layout_declare_ubo(B, "BufferMetadata", "outp")}
${layout_declare_ubo(B, "BufferMetadata", "indices")}
${layout_declare_ubo(B, "BufferMetadata", "weight")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

TensorIndex out_tidx_to_indices_tidx(const TensorIndex out_tidx) {
TensorIndex indices_tidx;
int d = 0;
// First half of the index
[[unroll]] for (uint d = 0; d < ndim(indices); ++d) {
indices_tidx.data[div_4(d)][mod_4(d)] = idx_at(out_tidx, d + 1);
}
[[unroll]] for (uint d = ndim(indices); d < DIMLIMIT; ++d) {
indices_tidx.data[div_4(d)][mod_4(d)] = 0;
}
return indices_tidx;
}

int load_embedding_idx(const TensorIndex indices_tidx) {
const uint bufi = tensor_idx_to_linear_idx(indices, indices_tidx);
return t_indices[bufi];
}

T load_weight_elem(const int embedding_idx, const uint dim_idx) {
uint bufi = uint(embedding_idx) * width(weight) + dim_idx;
return t_weight[bufi];
}

void main() {
const uint out_bufi = gl_GlobalInvocationID.x;
if (out_of_bounds(out_bufi, outp)) {
return;
}

TensorIndex out_tidx = linear_idx_to_tensor_idx(outp, out_bufi);
TensorIndex indices_tidx = out_tidx_to_indices_tidx(out_tidx);

const uint bufi = tensor_idx_to_linear_idx(indices, indices_tidx);
const int embedding_idx = load_embedding_idx(indices_tidx);

t_out[out_bufi] = load_weight_elem(embedding_idx, x(out_tidx));
}
16 changes: 16 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/embedding_buffer.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

embedding_buffer:
parameter_names_with_default_values:
DTYPE: float
STORAGE: buffer
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: embedding_buffer
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
embedding:
embedding_legacy:
parameter_names_with_default_values:
DTYPE: float
NDIM: 3
Expand All @@ -9,4 +9,4 @@ embedding:
- VALUE: float
- VALUE: int32
shader_variants:
- NAME: embedding
- NAME: embedding_legacy
71 changes: 71 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/embedding_texture.glsl
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, "texture3d")}
#define T ${texel_load_component_type(DTYPE, "texture3d")}

${define_active_storage_type("texture3d")}
${define_required_extensions(DTYPE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

#define DEBUG_MODE
#include "common.glslh"
#include "indexing.glslh"

${layout_declare_tensor(B, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(B, "r", "t_indices", "int", "texture3d")}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, "buffer")}

${layout_declare_ubo(B, "TextureMetadata", "outp")}
${layout_declare_ubo(B, "TextureMetadata", "indices")}
${layout_declare_ubo(B, "BufferMetadata", "weight")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

int load_embedding_idx(const TensorIndex4D out_tidx) {
TensorIndex4D indices_tidx;
indices_tidx.data.xyz = out_tidx.data.yzw;
indices_tidx.data.w = 0;

TextureElementIndex elem_pos = tensor_idx_to_texture_element_idx_simple(
indices_tidx, indices);

const ivec4 in_texel = texelFetch(t_indices, elem_pos.pos, 0);
return in_texel[elem_pos.comp];
}

VEC4_T load_weight_texel(const int embedding_idx, const int dim_idx) {
int buf_i = embedding_idx * int(width(weight)) + dim_idx;
VEC4_T weight_texel;
[[unroll]] for (int i = 0; i < 4; ++i) {
weight_texel[i] = T(t_weight[buf_i++]);
}
return weight_texel;
}

void main() {
const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

if (out_of_bounds(out_pos, outp)) {
return;
}

TensorIndex4D out_tidx = texture_pos_to_tensor_idx_simple(out_pos, outp);
const int embedding_idx = load_embedding_idx(out_tidx);

const VEC4_T weight_texel = load_weight_texel(embedding_idx, out_tidx.data.x);

imageStore(t_out, out_pos, weight_texel);
}
15 changes: 15 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/embedding_texture.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

embedding_texture:
parameter_names_with_default_values:
DTYPE: float
generate_variant_forall:
DTYPE:
- VALUE: half
- VALUE: float
shader_variants:
- NAME: embedding_texture3d
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
#ifndef IM2COL_PACKED_INT8_GLSLH
#define IM2COL_PACKED_INT8_GLSLH

#include "common.glslh"
#include "indexing.glslh"

struct Conv2dBlockElementIndex {
int x4;
Expand Down
Loading
Loading