@@ -48,7 +48,7 @@ void dequantize_f16(device const half4x4 * src, short il, thread type4x4 & reg)
4848
4949template <typename type4>
5050void dequantize_f16_t4 (device const half4 * src, short il, thread type4 & reg) {
51- reg = (type4)(*(src + il ));
51+ reg = (type4)(*(src));
5252}
5353
5454#if defined(GGML_METAL_USE_BF16)
@@ -59,7 +59,7 @@ void dequantize_bf16(device const bfloat4x4 * src, short il, thread type4x4 & re
5959
6060template <typename type4>
6161void dequantize_bf16_t4 (device const bfloat4 * src, short il, thread type4 & reg) {
62- reg = (type4)(*(src + il ));
62+ reg = (type4)(*(src));
6363}
6464#endif
6565
@@ -3644,7 +3644,7 @@ kernel void kernel_flash_attn_ext_vec(
36443644 const short DK4 = DK/4 ;
36453645 const short DV4 = DV/4 ;
36463646 const short NW = N_SIMDWIDTH;
3647- const short NL = NW/NE; // note: this can be adjusted to support different head sizes simdgroup work loads
3647+ const short NL = NW/NE; // note: this can be adjusted to support different head sizes and simdgroup work loads
36483648 const short SH = 2 *C; // shared memory per simdgroup
36493649
36503650 const short T = DK + nsg*SH; // shared memory size per query in (half)
@@ -3656,7 +3656,7 @@ kernel void kernel_flash_attn_ext_vec(
36563656 threadgroup half * sm = (threadgroup half *) (shmem_f16 + sgitg*SH + C + Q*DK); // scratch buffer for mask
36573657 threadgroup o4_t * sr4 = (threadgroup o4_t *) (shmem_f16 + sgitg*DV + Q*T); // scratch buffer for the results
36583658
3659- // store the result for all queries in local memory in 8x8 matrices (the O matrix from the paper)
3659+ // store the result for all queries in local memory (the O matrix from the paper)
36603660 o4_t lo[DV4/NL];
36613661
36623662 // load heads from Q to shared memory
@@ -3756,7 +3756,7 @@ kernel void kernel_flash_attn_ext_vec(
37563756 mqk += dot ((float4) mk, (float4) sq4[i]);
37573757 }
37583758
3759- static_assert (NE > 1 , " NE must be > 1" );
3759+ static_assert (NE > 1 , " NE must be > 1" ); // note: not sure why NE == 1 fails
37603760
37613761 // simdgroup reduce (NE = 4)
37623762 // [ 0 .. 7] -> [ 0]
0 commit comments