@@ -185,15 +185,16 @@ Status MatMulProgram::MakeMatMulPackedVec4Source(ShaderHelper& shader,
185185 << " " << (inner_elements_size == 3 ? " " : " acc[i] = BCached3 * ACached.w + acc[i];" ) << " \n "
186186 << " }\n " ;
187187 }
188- shader.MainFunctionBody () << " workgroupBarrier();\n "
189- << " }\n " ; // main for loop
188+ shader.MainFunctionBody ()
189+ << " }\n "
190+ << " workgroupBarrier();\n "
191+ << " }\n " ; // main for loop
190192
191193 // Write the results to the output buffer
192194 shader.MainFunctionBody ()
193195 << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n "
194196 << " mm_write(batch, globalRow + innerRow, globalCol, acc[innerRow]);\n "
195- << " }\n "
196- << " }\n " ;
197+ << " }\n " ;
197198
198199 return Status::OK ();
199200}
@@ -217,8 +218,8 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader,
217218
218219 const auto tile_a_outer = workgroup_size_y * elements_per_thread_y;
219220 const auto tile_b_outer = workgroup_size_x * elements_per_thread_x;
220- const auto tile_a_width = tile_inner;
221- const auto tile_a_height = tile_a_outer;
221+ const auto tile_a_width = transpose_a ? tile_a_outer : tile_inner;
222+ const auto tile_a_height = transpose_a ? tile_inner : tile_a_outer;
222223
223224 if (!(tile_a_height % workgroup_size_y == 0 && tile_a_width % workgroup_size_x == 0 && tile_inner % workgroup_size_y == 0 )) {
224225 return ORT_MAKE_STATUS (ONNXRUNTIME, INVALID_ARGUMENT,
@@ -243,7 +244,7 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader,
243244 << (nullptr != batch_dims ? " let batchIndices = " + batch_dims->OffsetToIndices (" u32(batch)" ) + " ;\n " : " " )
244245 << " let num_tiles = (uniforms.dim_inner - 1) / tileInner + 1;\n "
245246 << " var kStart = 0;\n "
246- << " var acc: array<vec4 <" << data_type << " >, rowPerThread>;\n " ;
247+ << " var acc: array<array <" << data_type << " , colPerThread >, rowPerThread>;\n " ;
247248
248249 if (sequentially_access_by_threads) {
249250 shader.MainFunctionBody () << " let localRow = i32(local_id.y);\n "
@@ -277,7 +278,7 @@ Status MatMulProgram::MakeMatMulPackedSource(ShaderHelper& shader,
277278 << " BCached[inner] = mm_Bsub[k][localCol + inner * " << workgroup_size_x << " ];\n "
278279 << " }\n "
279280 << " for (var innerRow = 0; innerRow < rowPerThread; innerRow = innerRow + 1) {\n "
280- << " let ACached = " << (transpose_a ? " mm_Asub[k][localCol + innerRow * " + std::to_string (workgroup_size_y) + " ];" : " mm_Asub[localRow + innerRow * " + std::to_string (workgroup_size_y) + " ][k];" ) << " \n "
281+ << " let ACached = " << (transpose_a ? " mm_Asub[k][localRow + innerRow * " + std::to_string (workgroup_size_y) + " ];" : " mm_Asub[localRow + innerRow * " + std::to_string (workgroup_size_y) + " ][k];" ) << " \n "
281282 << " for (var innerCol = 0; innerCol < colPerThread; innerCol = innerCol + 1) {\n "
282283 << " acc[innerRow][innerCol] = acc[innerRow][innerCol] +\n "
283284 << " ACached * BCached[innerCol];\n "
0 commit comments