Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 13 additions & 13 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,17 @@

${define_required_extensions(DTYPE)}

$if WEIGHT_STORAGE == "buffer":
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)}
$if QUANT_NBITS == 4:
${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
$if WEIGHT_STORAGE == "buffer":
${layout_declare_tensor(B, "r", "t_weight", "uint", WEIGHT_STORAGE, is_scalar_array=True)}
$else:
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
$if QUANT_NBITS == 4:
${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)}
$else:
${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)}
${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)}


Expand Down Expand Up @@ -91,22 +91,20 @@ void main() {
$if WEIGHT_STORAGE == "buffer":
uint qmat2_bufi;
uint weight_row_txstride = div4(weight_sizes.x);
uint encoded_weight;

// Preload weight tensor
for (int r = 0; r < 4; r++) {
T qmat2[TILE_TXCOLS * 4];
VEC4_T qmat2_vec4;
uvec4 packed_weight_tex;

$if QUANT_NBITS == 4:
$if WEIGHT_STORAGE == "buffer":
u8vec4 packed_weight_tex;
$else:
uvec4 packed_weight_tex;

$for c in range(0, TILE_TXCOLS, 2):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol;
packed_weight_tex = t_weight[qmat2_bufi + ${c}]
encoded_weight = t_weight[qmat2_bufi + ${c}];
packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
$else:
packed_weight_tex = texelFetch(
t_weight, ivec2(weight_txcol + ${c}, pos + r), 0);
Expand All @@ -126,7 +124,9 @@ void main() {
$for c in range(TILE_TXCOLS):
$if WEIGHT_STORAGE == "buffer":
qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol;
qmat2_vec4 = t_weight[qmat2_bufi + ${c}];
encoded_weight = t_weight[qmat2_bufi + ${c}];
packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24);
qmat2_vec4 = VEC4_T(packed_weight_tex);
$else:
qmat2_vec4 = VEC4_T(texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0));
$for j in range(4):
Expand Down
10 changes: 10 additions & 0 deletions backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,18 @@ linear_qcsnw_tiled:
- NAME: linear_qcs4w_tiled_texture3d_texture3d_texture2d_texture2d_float
TILE_TXCOLS: 2
QUANT_NBITS: 4
- NAME: linear_qcs4w_tiled_texture3d_texture3d_buffer_texture2d_float
TILE_TXCOLS: 2
QUANT_NBITS: 4
WEIGHT_STORAGE: buffer
- NAME: linear_qcs4w_tiled_buffer_buffer_texture2d_texture2d_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
TILE_TXCOLS: 2
QUANT_NBITS: 4
- NAME: linear_qcs4w_tiled_buffer_buffer_buffer_texture2d_float
IN_STORAGE: buffer
OUT_STORAGE: buffer
WEIGHT_STORAGE: buffer
TILE_TXCOLS: 2
QUANT_NBITS: 4
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@

$if not NO_INT8_BUFFERS:
${define_required_extensions("uint8")}
$if STORAGE == "buffer":
${define_required_extensions("int8")}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}
$if STORAGE == "buffer" and NO_INT8_BUFFERS:
${layout_declare_tensor(B, "w", "t_qmat2", "uint", STORAGE, is_scalar_array=True)}
$else:
${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)}

$if NO_INT8_BUFFERS:
${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")}
$else:
Expand All @@ -35,7 +37,10 @@ $else:
#define BUF_T uint8_t

$if STORAGE == "buffer":
#define UVEC4_T u8vec4
$if NO_INT8_BUFFERS:
#define UVEC4_T uvec4
$else:
#define UVEC4_T u8vec4
$else:
#define UVEC4_T uvec4

Expand All @@ -48,7 +53,7 @@ uint get_second(const BUF_T packed) {
}

// Merges two values into a single uint: `first` is shifted left by four bits
// and `second` is OR-ed into the result.
// NOTE(review): presumably both inputs are 4-bit values (0..15) — confirm
// against callers; only in that case do the two return expressions below
// produce the same result.
uint combine(const uint first, const uint second) {
return (first << 4 | second);
// NOTE(review): unreachable after the return above. This looks like the
// replacement line of a diff shown next to the line it replaces — confirm
// which form the real file contains.
return first * 16 + second;
}

$if NO_INT8_BUFFERS:
Expand Down Expand Up @@ -155,8 +160,12 @@ void main() {

$if STORAGE == "buffer":
int stride = qmat2_sizes.x >> 2;
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
$if NO_INT8_BUFFERS:
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1.x | (out_tex_1.y << 8) | (out_tex_1.z << 16) | (out_tex_1.w << 24);
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2.x | (out_tex_2.y << 8) | (out_tex_2.z << 16) | (out_tex_2.w << 24);
$else:
t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1;
t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2;
$else:
imageStore(t_qmat2, packed_pos.xy, out_tex_1);
imageStore(t_qmat2, ivec2(packed_pos.x, packed_pos.y + 1), out_tex_2);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,6 @@ pack_int4_linear_weight_transposed_interleaved:
STORAGE: buffer
- NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_texture2d
NO_INT8_BUFFERS: true
- NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_buffer
STORAGE: buffer
NO_INT8_BUFFERS: true
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,8 @@ void add_linear_qcs8w_node(
} else {
pcs = {
graph.logical_limits_pc_of(out_W_packed),
graph.sizes_pc_of(mat1_W_packed)};
graph.sizes_pc_of(mat1_W_packed),
graph.sizes_pc_of(q_mat2)};
}

const utils::uvec3 global_wg = {
Expand Down Expand Up @@ -351,7 +352,9 @@ void add_linear_qcsnw_tiled_node(
// Shader params buffers
{},
// Push Constants
{{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}},
{{graph.sizes_pc_of(out),
graph.sizes_pc_of(mat1),
graph.sizes_pc_of(q_mat2)}},
// Specialization Constants
{},
// Resize Args
Expand Down
Loading