Skip to content

Commit 7b7ecc0

Browse files
committed
metal : handle some edge cases when threadgroup size is not a power of 2
ggml-ci
1 parent 97819a0 commit 7b7ecc0

File tree

1 file changed

+6
-0
lines changed

1 file changed

+6
-0
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
2450 2450
nth *= 2;
2451 2451
}
2452 2452

2453 +
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
2453 2454
nth = MIN(nth, ne00);
2454 2455

2455 2456
ggml_metal_kargs_sum_rows args = {
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
3780 3781
nth *= 2;
3781 3782
}
3782 3783

3784 +
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3783 3785
nth = MIN(nth, ne00/4);
3784 3786

3785 3787
ggml_metal_kargs_rms_norm args = {
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
3816 3818
nth *= 2;
3817 3819
}
3818 3820

3821 +
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3819 3822
nth = MIN(nth, ne00/4);
3820 3823

3821 3824
ggml_metal_kargs_l2_norm args = {
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
3888 3891
nth *= 2;
3889 3892
}
3890 3893

3894 +
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
3891 3895
nth = MIN(nth, ne00/4);
3892 3896

3893 3897
ggml_metal_kargs_norm args = {
@@ -4986,6 +4990,8 @@ static bool ggml_metal_encode_node(
4986 4990
nth *= 2;
4987 4991
}
4988 4992

4993 +
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
4994 +
4989 4995
// when rows are small, we can batch them together in a single threadgroup
4990 4996
int nrptg = 1;
4991 4997

0 commit comments

Comments
 (0)