@@ -67,6 +67,7 @@ std::string NAInt8MatMulKernel::createSource() const noexcept {
6767 source.SetValue (" GROUP_M" , std::to_string (groupM));
6868 source.SetValue (" GROUP_N" , std::to_string (groupN));
6969 source.SetValue (" IO_TYPE" , ioPrecision.name ());
70+ source.SetValue (" M_VALUE" , loadM ? " M_dynamic" : " M" );
7071 source += R"(
7172#include <metal_stdlib>
7273#include <metal_tensor>
@@ -229,11 +230,11 @@ kernel void int8_matmul(
229230)" ;
230231 if (loadM) {
231232 source += R"(
232- const uint M = loadM_buf[0];
233+ const uniform< uint> M_dynamic = make_uniform( loadM_buf[0]) ;
233234)" ;
234235 }
235236 source += R"(
236- const uint M_tiles = (M + {{BLOCK_M}} - 1) / {{BLOCK_M}};
237+ const uint M_tiles = ({{M_VALUE}} + {{BLOCK_M}} - 1) / {{BLOCK_M}};
237238 const uint N_tiles = (N + {{BLOCK_N}} - 1) / {{BLOCK_N}};
238239 const uint M_tile_bits = M_tiles <= 1 ? 0 : 32 - clz(M_tiles - 1);
239240 const uint N_tile_bits = N_tiles <= 1 ? 0 : 32 - clz(N_tiles - 1);
@@ -245,12 +246,12 @@ kernel void int8_matmul(
245246 }
246247
247248 const uint M_block_start = tgid.y * {{BLOCK_M}};
248- const uint M_block_size = min((uint){{BLOCK_M}}, M - M_block_start);
249+ const uint M_block_size = min((uint){{BLOCK_M}}, {{M_VALUE}} - M_block_start);
249250 const uint N_block_start = tgid.x * {{BLOCK_N}};
250251 const uint N_block_size = min((uint){{BLOCK_N}}, N - N_block_start);
251252 const uint M_group_start = {{GROUP_M}} ? (M_block_start / {{GROUP_M}}) * {{GROUP_M}} : M_block_start;
252253 const uint M_group_offset = M_block_start - M_group_start;
253- const uint M_group_size = M - M_group_start;
254+ const uint M_group_size = {{M_VALUE}} - M_group_start;
254255 const uint N_group_start = {{GROUP_N}} ? (N_block_start / {{GROUP_N}}) * {{GROUP_N}} : N_block_start;
255256 const uint N_group_offset = N_block_start - N_group_start;
256257 const uint N_group_size = N - N_group_start;
@@ -284,7 +285,7 @@ kernel void int8_matmul(
284285 source += R"(
285286 auto A = tensor<device int8_t, dextents<int32_t, 2>, tensor_inline>(A_buf, dextents<int32_t, 2>(K, M_group_size));
286287 auto B = tensor<device int8_t, dextents<int32_t, 2>, tensor_inline>(B_buf, dextents<int32_t, 2>(K, N_group_size));
287- if (N_block_start + {{BLOCK_N}} - 1 < N && M_block_start + {{BLOCK_M}} - 1 < M ) {
288+ if (N_block_start + {{BLOCK_N}} - 1 < N && M_block_start + {{BLOCK_M}} - 1 < {{M_VALUE}} ) {
288289 constexpr auto matmul_descriptor = matmul2d_descriptor(
289290 {{BLOCK_M}},
290291 {{BLOCK_N}},
0 commit comments