Skip to content

Commit cab1def

Browse files
committed
Set M to be uniform.
1 parent 92496b6 commit cab1def

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

lib/nnc/mfa/kernels/NAInt8MatMulKernel.cpp

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ std::string NAInt8MatMulKernel::createSource() const noexcept {
6767
source.SetValue("GROUP_M", std::to_string(groupM));
6868
source.SetValue("GROUP_N", std::to_string(groupN));
6969
source.SetValue("IO_TYPE", ioPrecision.name());
70+
source.SetValue("M_VALUE", loadM ? "M_dynamic" : "M");
7071
source += R"(
7172
#include <metal_stdlib>
7273
#include <metal_tensor>
@@ -229,11 +230,11 @@ kernel void int8_matmul(
229230
)";
230231
if (loadM) {
231232
source += R"(
232-
const uint M = loadM_buf[0];
233+
const uniform<uint> M_dynamic = make_uniform(loadM_buf[0]);
233234
)";
234235
}
235236
source += R"(
236-
const uint M_tiles = (M + {{BLOCK_M}} - 1) / {{BLOCK_M}};
237+
const uint M_tiles = ({{M_VALUE}} + {{BLOCK_M}} - 1) / {{BLOCK_M}};
237238
const uint N_tiles = (N + {{BLOCK_N}} - 1) / {{BLOCK_N}};
238239
const uint M_tile_bits = M_tiles <= 1 ? 0 : 32 - clz(M_tiles - 1);
239240
const uint N_tile_bits = N_tiles <= 1 ? 0 : 32 - clz(N_tiles - 1);
@@ -245,12 +246,12 @@ kernel void int8_matmul(
245246
}
246247
247248
const uint M_block_start = tgid.y * {{BLOCK_M}};
248-
const uint M_block_size = min((uint){{BLOCK_M}}, M - M_block_start);
249+
const uint M_block_size = min((uint){{BLOCK_M}}, {{M_VALUE}} - M_block_start);
249250
const uint N_block_start = tgid.x * {{BLOCK_N}};
250251
const uint N_block_size = min((uint){{BLOCK_N}}, N - N_block_start);
251252
const uint M_group_start = {{GROUP_M}} ? (M_block_start / {{GROUP_M}}) * {{GROUP_M}} : M_block_start;
252253
const uint M_group_offset = M_block_start - M_group_start;
253-
const uint M_group_size = M - M_group_start;
254+
const uint M_group_size = {{M_VALUE}} - M_group_start;
254255
const uint N_group_start = {{GROUP_N}} ? (N_block_start / {{GROUP_N}}) * {{GROUP_N}} : N_block_start;
255256
const uint N_group_offset = N_block_start - N_group_start;
256257
const uint N_group_size = N - N_group_start;
@@ -284,7 +285,7 @@ kernel void int8_matmul(
284285
source += R"(
285286
auto A = tensor<device int8_t, dextents<int32_t, 2>, tensor_inline>(A_buf, dextents<int32_t, 2>(K, M_group_size));
286287
auto B = tensor<device int8_t, dextents<int32_t, 2>, tensor_inline>(B_buf, dextents<int32_t, 2>(K, N_group_size));
287-
if (N_block_start + {{BLOCK_N}} - 1 < N && M_block_start + {{BLOCK_M}} - 1 < M) {
288+
if (N_block_start + {{BLOCK_N}} - 1 < N && M_block_start + {{BLOCK_M}} - 1 < {{M_VALUE}}) {
288289
constexpr auto matmul_descriptor = matmul2d_descriptor(
289290
{{BLOCK_M}},
290291
{{BLOCK_N}},

lib/nnc/mfa/kernels/NAMatMulKernel.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ kernel void matmul(device {{MEMORY_NAME_A}} *A_buf [[buffer(0)]],
212212
)";
213213
if (loadM) {
214214
source += R"(
215-
const uint M = loadM[0];
215+
const uniform<uint> M = make_uniform(loadM[0]);
216216
)";
217217
}
218218
source += R"(
@@ -600,7 +600,7 @@ kernel void reduce_sum_2(device {{MEMORY_NAME_C}}2 *A_buf [[buffer(0)]],
600600
)";
601601
if (loadM) {
602602
source += R"(
603-
const uint M = loadM[0];
603+
const uniform<uint> M = make_uniform(loadM[0]);
604604
)";
605605
}
606606
source += R"(
@@ -633,7 +633,7 @@ kernel void reduce_sum(device {{MEMORY_NAME_C}} *A_buf [[buffer(0)]],
633633
)";
634634
if (loadM) {
635635
source += R"(
636-
const uint M = loadM[0];
636+
const uniform<uint> M = make_uniform(loadM[0]);
637637
)";
638638
}
639639
source += R"(

0 commit comments

Comments
 (0)