Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 30 additions & 37 deletions backends/vulkan/runtime/graph/ops/glsl/q_8w_linear.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -52,19 +52,26 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
#define FLOAT_T float
#endif

FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
const FLOAT_T scale = t_scales[out_idx.x];
void main() {
const int out_bufi = int(gl_GlobalInvocationID.x);
if (out_bufi >= out_numel) {
return;
}

const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);

const FLOAT_T scale = t_scales[out_tidx.x];

FLOAT_T outval = FLOAT_T(0.0);

// Initial mat1 tensor idx will be (0, out_idx.y, out_idx.z, 0)
int mat1_offset = out_idx.y * mat1_strides.y + out_idx.z * qmat2_strides.z;
// Initial qmat2 tensor idx wil be (0, out_idx.x, 0, 0); note that the qmat2
// Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0)
int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z;
// Initial qmat2 tensor idx wil be (0, out_tidx.x, 0, 0); note that the qmat2
// tensor is transposed
int qmat2_offset = out_idx.x * qmat2_strides.y;
int qmat2_offset = out_tidx.x * qmat2_strides.y;

// TODO(ssjia): optimize memory access pattern by traversing K in inner loop
for (int i = 0; i < K; i++) {
// TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
for (int i = 0; i < mat1_sizes.x; i++) {
const FLOAT_T mat1_val = t_mat1[mat1_offset];
const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;

Expand All @@ -74,33 +81,32 @@ FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
qmat2_offset++;
}

return outval;
}

void main() {
const int out_bufi = int(gl_GlobalInvocationID.x);
if (out_bufi >= out_numel) {
return;
}

const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0);

t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);
t_out[out_bufi] = outval;
}

#else // USING_TEXTURE

#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) {
void main() {
const u16vec2 out_pos = u16vec2(
gl_GlobalInvocationID.x / out_limits.y,
gl_GlobalInvocationID.x % out_limits.y);
if (out_pos.x >= out_limits.x) {
return;
}

const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4);

VEC4_T outtex = VEC4_T(0);

const u16vec3 scales_pos = u16vec3(out_pos.x, 0, 0);
const VEC4_T scales = load_texel(t_scales, scales_pos);
const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0, 0));

for (uint16_t i = uint16_t(0), x = uint16_t(0); i < K; i += uint16_t(4), x++) {
for (
uint16_t i = uint16_t(0), x = uint16_t(0);
i < uint16_t(mat1_sizes.x);
i += uint16_t(4), x++)
{
const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0));
const VEC4_T sums = VEC4_T(
dot(mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0))),
Expand All @@ -112,19 +118,6 @@ VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) {
}

outtex *= scales;

return outtex;
}

void main() {
const u16vec2 out_pos = u16vec2(
gl_GlobalInvocationID.x / out_limits.y,
gl_GlobalInvocationID.x % out_limits.y);
if (out_pos.x >= out_limits.x) {
return;
}

VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x));
write_texel(t_out, u16vec3(out_pos, 0), outtex);
}

Expand Down
Loading