11#version 450
2- #extension GL_EXT_shader_explicit_arithmetic_types : require
2+ #extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
33
44#include "mul_mat_vec_base.comp"
55
@@ -40,9 +40,9 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
4040
4141 [[unroll]] for (uint n = 0; n < num_rows; ++n) {
4242 const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
43- f16vec2 d = data_a[ib0 + i].d;
44- const FLOAT_TYPE dall = d.x;
45- const FLOAT_TYPE dmin = d.y;
43+ vec2 d = vec2( data_a[ib0 + i].d) ;
44+ const FLOAT_TYPE dall = FLOAT_TYPE( d.x) ;
45+ const FLOAT_TYPE dmin = FLOAT_TYPE( d.y) ;
4646
4747 uint32_t s0_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 0];
4848 uint32_t s4_u32 = data_a_packed32[ib0 + i].scales[s_offset / 4 + 1];
@@ -63,14 +63,14 @@ void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
6363 uvec2 qs16 = uvec2(unpack8(qs16_u16));
6464
6565 [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
66- B_TYPE_VEC2 b0 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0] ;
67- B_TYPE_VEC2 b16 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8] ;
68- B_TYPE_VEC2 b32 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16];
69- B_TYPE_VEC2 b48 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24];
70- B_TYPE_VEC2 b64 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32];
71- B_TYPE_VEC2 b80 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40];
72- B_TYPE_VEC2 b96 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48];
73- B_TYPE_VEC2 b112 = data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56];
66+ vec2 b0 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]) ;
67+ vec2 b16 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]) ;
68+ vec2 b32 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]) ;
69+ vec2 b48 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]) ;
70+ vec2 b64 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]) ;
71+ vec2 b80 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]) ;
72+ vec2 b96 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]) ;
73+ vec2 b112 = vec2( data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]) ;
7474
7575 FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
7676 FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
0 commit comments