Skip to content

Commit d20b02d

Browse files
refactor: Simplify shared memory sizing
Branch: GraniteFourPerf Signed-off-by: Gabe Goodhart <[email protected]> Co-Authored-By: Georgi Gerganov <[email protected]>
1 parent c3711e1 commit d20b02d

File tree

1 file changed

+4
-11
lines changed

1 file changed

+4
-11
lines changed

ggml/src/ggml-metal/ggml-metal.m

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -3171,19 +3171,12 @@ static int ggml_metal_encode_node(
31713171
[encoder setBytes:&args length:sizeof(args) atIndex:8];
31723172

31733173
// One shared memory bucket for each simd group in the threadgroup
3174+
// NOTE: Metal kernels require the buffer size to be multiple of 16 bytes
3175+
// https://developer.apple.com/documentation/metal/mtlcomputecommandencoder/1443142-setthreadgroupmemorylength
31743176
if (d_state >= 32) {
3175-
const int64_t shmem_size = d_state / 32;
3176-
3177-
// The final simd_sum won't work if the number of simd groups is
3178-
// larger than the size of a single simd group. If this case is
3179-
// hit at some point, the logic in the second simd_sum could be
3180-
// expanded to handle this with one more sequential simd_sum to
3181-
// collapse simd group sums another time.
3182-
GGML_ASSERT(shmem_size <= 32);
3183-
3184-
// One thread pre element in d_state
3177+
GGML_ASSERT((int64_t)(d_state / 32) <= 32);
3178+
const int64_t shmem_size = 32;
31853179
GGML_ASSERT(d_state <= (int64_t)pipeline.maxTotalThreadsPerThreadgroup);
3186-
31873180
[encoder setThreadgroupMemoryLength:(shmem_size)*sizeof(float) atIndex:0];
31883181
}
31893182

0 commit comments

Comments
 (0)