diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl index d2e6b4688eb..db627681a3a 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.glsl @@ -18,17 +18,17 @@ ${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": - ${define_required_extensions("int8")} - layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_in", DTYPE, IN_STORAGE, is_scalar_array=False)} -$if QUANT_NBITS == 4: - ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +$if WEIGHT_STORAGE == "buffer": + ${layout_declare_tensor(B, "r", "t_weight", "uint", WEIGHT_STORAGE, is_scalar_array=True)} $else: - ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} + $if QUANT_NBITS == 4: + ${layout_declare_tensor(B, "r", "t_weight", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} + $else: + ${layout_declare_tensor(B, "r", "t_weight", "int8", WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_scales", DTYPE, SCALES_STORAGE, is_scalar_array=False)} @@ -91,22 +91,20 @@ void main() { $if WEIGHT_STORAGE == "buffer": uint qmat2_bufi; uint weight_row_txstride = div4(weight_sizes.x); + uint encoded_weight; // Preload weight tensor for (int r = 0; r < 4; r++) { T qmat2[TILE_TXCOLS * 4]; VEC4_T qmat2_vec4; + uvec4 packed_weight_tex; $if QUANT_NBITS == 4: - $if WEIGHT_STORAGE == "buffer": - u8vec4 packed_weight_tex; - $else: - uvec4 packed_weight_tex; - $for c in range(0, TILE_TXCOLS, 2): $if WEIGHT_STORAGE == "buffer": qmat2_bufi = (pos + r) * weight_row_txstride + weight_txcol; - packed_weight_tex = t_weight[qmat2_bufi + ${c}] + encoded_weight = t_weight[qmat2_bufi + ${c}]; + packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); $else: packed_weight_tex = texelFetch( t_weight, ivec2(weight_txcol + ${c}, pos + r), 0); @@ -126,7 +124,9 @@ void main() { $for c in range(TILE_TXCOLS): $if WEIGHT_STORAGE == "buffer": qmat2_bufi = (pos + r) * weight_row_txstride + out_txcol; - qmat2_vec4 = t_weight[qmat2_bufi + ${c}]; + encoded_weight = t_weight[qmat2_bufi + ${c}]; + packed_weight_tex = uvec4(encoded_weight & 0xFF, (encoded_weight >> 8) & 0xFF, (encoded_weight >> 16) & 0xFF, encoded_weight >> 24); + qmat2_vec4 = VEC4_T(packed_weight_tex); $else: qmat2_vec4 = VEC4_T(texelFetch(t_weight, ivec2(out_txcol + ${c}, pos + r), 0)); $for j in range(4): diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml index bbae284349a..7eba788a1d6 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qcsnw_tiled.yaml @@ -35,8 +35,18 @@ linear_qcsnw_tiled: - NAME: linear_qcs4w_tiled_texture3d_texture3d_texture2d_texture2d_float TILE_TXCOLS: 2 QUANT_NBITS: 4 + - NAME: linear_qcs4w_tiled_texture3d_texture3d_buffer_texture2d_float + TILE_TXCOLS: 2 + QUANT_NBITS: 4 + WEIGHT_STORAGE: buffer - NAME: linear_qcs4w_tiled_buffer_buffer_texture2d_texture2d_float IN_STORAGE: buffer OUT_STORAGE: buffer TILE_TXCOLS: 2 QUANT_NBITS: 4 + - NAME: linear_qcs4w_tiled_buffer_buffer_buffer_texture2d_float + IN_STORAGE: buffer + OUT_STORAGE: buffer + WEIGHT_STORAGE: buffer + TILE_TXCOLS: 2 + QUANT_NBITS: 4 diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl index 0079526c248..18e9b4c7275 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.glsl @@ -12,12 +12,14 @@ $if not NO_INT8_BUFFERS: ${define_required_extensions("uint8")} -$if STORAGE == "buffer": - ${define_required_extensions("int8")} layout(std430) buffer; -${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)} +$if STORAGE == "buffer" and NO_INT8_BUFFERS: + ${layout_declare_tensor(B, "w", "t_qmat2", "uint", STORAGE, is_scalar_array=True)} +$else: + ${layout_declare_tensor(B, "w", "t_qmat2", "uint8", STORAGE, is_scalar_array=False)} + $if NO_INT8_BUFFERS: ${layout_declare_tensor(B, "r", "nchw_4x2", "uint", "buffer")} $else: @@ -35,7 +37,10 @@ $else: #define BUF_T uint8_t $if STORAGE == "buffer": - #define UVEC4_T u8vec4 + $if NO_INT8_BUFFERS: + #define UVEC4_T uvec4 + $else: + #define UVEC4_T u8vec4 $else: #define UVEC4_T uvec4 @@ -48,7 +53,7 @@ uint get_second(const BUF_T packed) { } uint combine(const uint first, const uint second) { - return (first << 4 | second); + return first * 16 + second; } $if NO_INT8_BUFFERS: @@ -155,8 +160,12 @@ void main() { $if STORAGE == "buffer": int stride = qmat2_sizes.x >> 2; - t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1; - t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2; + $if NO_INT8_BUFFERS: + t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1.x | (out_tex_1.y << 8) | (out_tex_1.z << 16) | (out_tex_1.w << 24); + t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2.x | (out_tex_2.y << 8) | (out_tex_2.z << 16) | (out_tex_2.w << 24); + $else: + t_qmat2[packed_pos.y * stride + packed_pos.x] = out_tex_1; + t_qmat2[(packed_pos.y + 1) * stride + packed_pos.x] = out_tex_2; $else: imageStore(t_qmat2, packed_pos.xy, out_tex_1); imageStore(t_qmat2, ivec2(packed_pos.x, packed_pos.y + 1), out_tex_2); diff --git a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml index 145f4301f14..6bddb4c62cd 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/pack_int4_linear_weight_transposed_interleaved.yaml @@ -14,3 +14,6 @@ pack_int4_linear_weight_transposed_interleaved: STORAGE: buffer - NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_texture2d NO_INT8_BUFFERS: true + - NAME: pack_int4_linear_weight_transposed_interleaved_nobitw8buffer_buffer + STORAGE: buffer + NO_INT8_BUFFERS: true diff --git a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp index 971291cb11f..ad65ebfe82d 100644 --- a/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/QuantizedLinearQCSNW.cpp @@ -225,7 +225,8 @@ void add_linear_qcs8w_node( } else { pcs = { graph.logical_limits_pc_of(out_W_packed), - graph.sizes_pc_of(mat1_W_packed)}; + graph.sizes_pc_of(mat1_W_packed), + graph.sizes_pc_of(q_mat2)}; } const utils::uvec3 global_wg = { @@ -351,7 +352,9 @@ void add_linear_qcsnw_tiled_node( // Shader params buffers {}, // Push Constants - {{graph.sizes_pc_of(out), graph.sizes_pc_of(mat1)}}, + {{graph.sizes_pc_of(out), + graph.sizes_pc_of(mat1), + graph.sizes_pc_of(q_mat2)}}, // Specialization Constants {}, // Resize Args