@@ -52,19 +52,26 @@ layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
5252#define FLOAT_T float
5353#endif
5454
55- FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
56- const FLOAT_T scale = t_scales[out_idx.x];
55+ void main() {
56+ const int out_bufi = int (gl_GlobalInvocationID.x);
57+ if (out_bufi >= out_numel) {
58+ return ;
59+ }
60+
61+ const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0 );
62+
63+ const FLOAT_T scale = t_scales[out_tidx.x];
5764
5865 FLOAT_T outval = FLOAT_T(0.0 );
5966
60- // Initial mat1 tensor idx will be (0, out_idx .y, out_idx .z, 0)
61- int mat1_offset = out_idx .y * mat1_strides.y + out_idx .z * qmat2_strides.z;
62- // Initial qmat2 tensor idx wil be (0, out_idx .x, 0, 0); note that the qmat2
67+ // Initial mat1 tensor idx will be (0, out_tidx .y, out_tidx .z, 0)
68+ int mat1_offset = out_tidx .y * mat1_strides.y + out_tidx .z * qmat2_strides.z;
69+ // Initial qmat2 tensor idx wil be (0, out_tidx .x, 0, 0); note that the qmat2
6370 // tensor is transposed
64- int qmat2_offset = out_idx .x * qmat2_strides.y;
71+ int qmat2_offset = out_tidx .x * qmat2_strides.y;
6572
66- // TODO(ssjia): optimize memory access pattern by traversing K in inner loop
67- for (int i = 0 ; i < K ; i++ ) {
73+ // TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
74+ for (int i = 0 ; i < mat1_sizes.x ; i++ ) {
6875 const FLOAT_T mat1_val = t_mat1[mat1_offset];
6976 const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale;
7077
@@ -74,33 +81,32 @@ FLOAT_T q_8w_linear(const ivec4 out_idx, const int K) {
7481 qmat2_offset++ ;
7582 }
7683
77- return outval;
78- }
79-
80- void main() {
81- const int out_bufi = int (gl_GlobalInvocationID.x);
82- if (out_bufi >= out_numel) {
83- return ;
84- }
85-
86- const ivec4 out_tidx = bufi_to_tidx(out_bufi, out_strides, 0 );
87-
88- t_out[out_bufi] = q_8w_linear(out_tidx, mat1_sizes.x);
84+ t_out[out_bufi] = outval;
8985}
9086
9187#else // USING_TEXTURE
9288
9389#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
9490
95- VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) {
91+ void main() {
92+ const u16vec2 out_pos = u16vec2(
93+ gl_GlobalInvocationID.x / out_limits.y,
94+ gl_GlobalInvocationID.x % out_limits.y);
95+ if (out_pos.x >= out_limits.x) {
96+ return ;
97+ }
98+
9699 const uint16_t qmat2_pos_y = out_pos.x * uint16_t(4 );
97100
98101 VEC4_T outtex = VEC4_T(0 );
99102
100- const u16vec3 scales_pos = u16vec3(out_pos.x, 0 , 0 );
101- const VEC4_T scales = load_texel(t_scales, scales_pos);
103+ const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0 , 0 ));
102104
103- for (uint16_t i = uint16_t(0 ), x = uint16_t(0 ); i < K; i += uint16_t(4 ), x++ ) {
105+ for (
106+ uint16_t i = uint16_t(0 ), x = uint16_t(0 );
107+ i < uint16_t(mat1_sizes.x);
108+ i += uint16_t(4 ), x++ )
109+ {
104110 const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0 ));
105111 const VEC4_T sums = VEC4_T(
106112 dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y, 0 ))),
@@ -112,19 +118,6 @@ VEC4_T q_8w_linear(const u16vec2 out_pos, const uint16_t K) {
112118 }
113119
114120 outtex *= scales;
115-
116- return outtex;
117- }
118-
119- void main() {
120- const u16vec2 out_pos = u16vec2(
121- gl_GlobalInvocationID.x / out_limits.y,
122- gl_GlobalInvocationID.x % out_limits.y);
123- if (out_pos.x >= out_limits.x) {
124- return ;
125- }
126-
127- VEC4_T outtex = q_8w_linear(out_pos, uint16_t(mat1_sizes.x));
128121 write_texel(t_out, u16vec3(out_pos, 0 ), outtex);
129122}
130123
0 commit comments