diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl index 4b1e2b6b7be..2c0336ae96f 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.glsl @@ -19,7 +19,7 @@ #define NWORKERS 8 ${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": +$if WEIGHT_STORAGE == "buffer" and WEIGHT_DTYPE == "uint8": ${define_required_extensions("uint8")} #extension GL_EXT_control_flow_attributes : require @@ -28,7 +28,7 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", WEIGHT_DTYPE, WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_qparams", "float", PARAMS_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)} ${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml index 2d8a979494c..d1aff714a53 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_coop.yaml @@ -7,6 +7,7 @@ linear_qta8a_qga4w_qta8o_coop: parameter_names_with_default_values: DTYPE: int8 + WEIGHT_DTYPE: uint8 OUT_STORAGE: texture3d IN_STORAGE: texture3d WEIGHT_STORAGE: texture2d @@ -24,3 +25,6 @@ linear_qta8a_qga4w_qta8o_coop: - NAME: linear_qta8a_qga4w_qta8o_coop_buffer_texture2d_buffer_int8 OUT_STORAGE: buffer WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_qta8o_coop_texture3d_texture3d_texture2d_int32 + DTYPE: int32 + WEIGHT_DTYPE: uint32 diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl index 7b4f2733066..ad4c20e95d0 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.glsl @@ -16,7 +16,7 @@ #define TILE_ROWS ${TILE_ROWS} ${define_required_extensions(DTYPE)} -$if WEIGHT_STORAGE == "buffer": +$if WEIGHT_STORAGE == "buffer" and WEIGHT_DTYPE == "uint8": ${define_required_extensions("uint8")} #extension GL_EXT_control_flow_attributes : require @@ -25,7 +25,7 @@ layout(std430) buffer; ${layout_declare_tensor(B, "w", "t_out", DTYPE, OUT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_mat1", DTYPE, IN_STORAGE, is_scalar_array=False)} -${layout_declare_tensor(B, "r", "t_qmat2", "uint8", WEIGHT_STORAGE, is_scalar_array=False)} +${layout_declare_tensor(B, "r", "t_qmat2", WEIGHT_DTYPE, WEIGHT_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_qparams", "float", PARAMS_STORAGE, is_scalar_array=False)} ${layout_declare_tensor(B, "r", "t_input_scale", "float", "buffer", is_scalar_array=True)} ${layout_declare_tensor(B, "r", "t_input_zero_point", "int", "buffer", is_scalar_array=True)} diff --git a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml index 9de102cf5f0..55a41129030 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/linear_qta8a_qga4w_qta8o_tiled.yaml @@ -7,6 +7,7 @@ linear_qta8a_qga4w_qta8o_tiled: parameter_names_with_default_values: DTYPE: int8 + WEIGHT_DTYPE: uint8 OUT_STORAGE: texture3d IN_STORAGE: texture3d WEIGHT_STORAGE: texture2d @@ -24,3 +25,6 @@ linear_qta8a_qga4w_qta8o_tiled: - NAME: linear_qta8a_qga4w_qta8o_tiled_buffer_texture2d_buffer_int8 OUT_STORAGE: buffer WEIGHT_STORAGE: buffer + - NAME: linear_qta8a_qga4w_qta8o_tiled_texture3d_texture3d_texture2d_int32 + DTYPE: int32 + WEIGHT_DTYPE: uint32