@@ -64,24 +64,21 @@ void main() {
6464
6565 FLOAT_T outval = FLOAT_T(0.0 );
6666
67- // Initial mat1 tensor idx will be (0, out_tidx.y, out_tidx.z, 0)
6867 int mat1_offset = out_tidx.y * mat1_strides.y + out_tidx.z * qmat2_strides.z;
69- // Initial qmat2 tensor idx will be (0, out_tidx.x, 0, 0); note that the qmat2
70- // tensor is transposed
71- int qmat2_offset = out_tidx.x * qmat2_strides.y;
68+ int qmat2_offset = out_tidx.x;
7269
7370 // TODO(ssjia): optimize memory access pattern by traversing mat1 x in inner loop
7471 for (int i = 0 ; i < mat1_sizes.x; i++ ) {
7572 const FLOAT_T mat1_val = t_mat1[mat1_offset];
76- const FLOAT_T mat2_val = t_qmat2[qmat2_offset] * scale ;
73+ const FLOAT_T mat2_val = FLOAT_T( t_qmat2[qmat2_offset]) ;
7774
7875 outval += mat1_val * mat2_val;
7976
8077 mat1_offset++ ;
81- qmat2_offset++ ;
78+ qmat2_offset += qmat2_strides.y ;
8279 }
8380
84- t_out[out_bufi] = outval;
81+ t_out[out_bufi] = outval * scale ;
8582}
8683
8784#else // USING_TEXTURE
@@ -97,25 +94,27 @@ void main() {
9794 return ;
9895 }
9996
100- const uint16_t qmat2_pos_y = out_pos.x * uint16_t( 4 ) ;
97+ const uint16_t qmat2_pos_x = out_pos.x;
10198
10299 VEC4_T outtex = VEC4_T(0 );
103100
104101 const VEC4_T scales = load_texel(t_scales, u16vec3(out_pos.x, 0 , 0 ));
105102
103+ VEC4_T mat1_tex;
104+ VEC4_T mat2_tex[4 ];
106105 for (
107106 uint16_t i = uint16_t(0 ), x = uint16_t(0 );
108107 i < uint16_t(mat1_sizes.x);
109108 i += uint16_t(4 ), x++ )
110109 {
111- const VEC4_T mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0 ));
112- const VEC4_T sums = VEC4_T(
113- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y , 0 ))),
114- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(1 ), 0 ))),
115- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(2 ), 0 ))),
116- dot (mat1_tex, load_texel(t_qmat2, u16vec3(x, qmat2_pos_y + uint16_t(3 ), 0 )) ));
117-
118- outtex += sums ;
110+ mat1_tex = load_texel(t_mat1, u16vec3(x, out_pos.y, 0 ));
111+
112+ mat2_tex[ 0 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i , 0 ));
113+ mat2_tex[ 1 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i + uint16_t(1 ), 0 ));
114+ mat2_tex[ 2 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i + uint16_t(2 ), 0 ));
115+ mat2_tex[ 3 ] = load_texel(t_qmat2, u16vec3(out_pos. x, i + uint16_t(3 ), 0 ));
116+
117+ outtex += mat1_tex.x * mat2_tex[ 0 ] + mat1_tex.y * mat2_tex[ 1 ] + mat1_tex.z * mat2_tex[ 2 ] + mat1_tex.w * mat2_tex[ 3 ] ;
119118 }
120119
121120 outtex *= scales;
0 commit comments