@@ -138,8 +138,8 @@ Status DP4AMatMulNBitsProgram::GenerateShaderCode(ShaderHelper& shader) const {
138138 shader.MainFunctionBody () << R"MAIN_FN(
139139 // During the load phase we use all 256 threads to load 64 rows of A/B.
140140 // For each row we load tile_size_k_vec (2) vectorized elements, which are 32 elements of K.
141- let a_global_base = workgroup_id.x * tile_size;
142- let b_global_base = workgroup_id.y * tile_size;
141+ let a_global_base = u32(workgroup_idx / uniforms.num_N_tile) * tile_size;
142+ let b_global_base = (workgroup_idx % uniforms.num_N_tile) * tile_size;
143143 let load_AorB = u32(local_idx/128);
144144 let load_row = u32((local_idx%128)/2);
145145 let load_col = u32(local_idx%2);
@@ -275,11 +275,11 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
275275
276276 constexpr uint32_t kTileSize = 64 ;
277277 TensorShape reshaped_y_shape{1 , M, N / kVec4Components };
278+ uint32_t num_M_tile = (M + kTileSize - 1 ) / kTileSize ;
279+ uint32_t num_N_tile = (N + kTileSize - 1 ) / kTileSize ;
278280 DP4AMatMulNBitsProgram mul_program{block_size};
279281 mul_program.SetWorkgroupSize (256 );
280- mul_program.SetDispatchGroupSize (
281- (M + kTileSize - 1 ) / kTileSize ,
282- (N + kTileSize - 1 ) / kTileSize , 1 );
282+ mul_program.SetDispatchGroupSize (num_M_tile * num_N_tile);
283283 mul_program.AddInputs ({{&a_quant, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(kVec4Components )},
284284 {&a_scale, ProgramTensorMetadataDependency::TypeAndRank, 1 },
285285 {b, ProgramTensorMetadataDependency::TypeAndRank, static_cast <int >(kVec2Components * kU32Components )},
@@ -288,7 +288,8 @@ Status ApplyDP4AMatrixMatMulNBits(const Tensor* a, const Tensor* b, const Tensor
288288 {static_cast <uint32_t >(N)},
289289 {static_cast <uint32_t >(K)},
290290 {static_cast <uint32_t >(K / 8 )},
291- {static_cast <uint32_t >(K / 16 )}})
291+ {static_cast <uint32_t >(K / 16 )},
292+ {num_N_tile}})
292293 .AddOutput ({y, ProgramTensorMetadataDependency::TypeAndRank, reshaped_y_shape, static_cast <int >(kVec4Components )})
293294 .CacheHint (" Block" + std::to_string (block_size));
294295 return context.RunProgram (mul_program);
0 commit comments