|
14 | 14 |
|
15 | 15 | using namespace ggml_cuda_mma; |
16 | 16 |
|
| 17 | +typedef void (* fattn_new_mma_kernel_t)( |
| 18 | + const char * __restrict__ Q, |
| 19 | + const char * __restrict__ K, |
| 20 | + const char * __restrict__ V, |
| 21 | + const char * __restrict__ mask, |
| 22 | + const char * __restrict__ sinks, |
| 23 | + const int * __restrict__ KV_max, |
| 24 | + float * __restrict__ dst, |
| 25 | + float2 * __restrict__ dst_meta, |
| 26 | + const float scale, |
| 27 | + const float max_bias, |
| 28 | + const float m0, |
| 29 | + const float m1, |
| 30 | + const float softcap, |
| 31 | + const uint32_t n_head_log2, |
| 32 | + const int ne00, |
| 33 | + const int ne01, |
| 34 | + const int ne02, |
| 35 | + const int ne03, |
| 36 | + const int ne10, |
| 37 | + const int ne11, |
| 38 | + const int ne12, |
| 39 | + const int ne13, |
| 40 | + const int ne31, |
| 41 | + const int nb31, |
| 42 | + const int nb01, |
| 43 | + const int nb02, |
| 44 | + const int nb03, |
| 45 | + const int nb11, |
| 46 | + const int nb12, |
| 47 | + const int nb13, |
| 48 | + const int nb21, |
| 49 | + const int nb22, |
| 50 | + const int nb23, |
| 51 | + const int ne0, |
| 52 | + const int ne1, |
| 53 | + const int ne2, |
| 54 | + const int ne3); |
| 55 | + |
| 56 | + |
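Note: the typedef above repeats the generic fattn_kernel_t parameter list but adds the KV_max pointer, which is presumably why the launcher further down is retyped to fattn_new_mma_kernel_t; a single function-pointer type can then hold any template instantiation of flash_attn_ext_f16. A standalone sketch of that dispatch pattern, with toy kernel names that are not part of the real file:

    // Sketch only: one function-pointer type covers every template instantiation,
    // so the host picks a kernel at runtime and hands it to a generic launcher.
    typedef void (* toy_kernel_t)(const float * in, float * out, int n);

    template <int block_size, bool softcap>
    static __global__ void toy_kernel(const float * in, float * out, int n) {
        const int i = blockIdx.x*block_size + threadIdx.x;
        if (i < n) {
            out[i] = softcap ? tanhf(in[i]) : in[i];
        }
    }

    static toy_kernel_t pick_toy_kernel(const bool use_softcap) {
        return use_softcap ? toy_kernel<256, true> : toy_kernel<256, false>;
    }
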
17 | 57 | typedef tile<16, 8, half2> tile_A; |
18 | 58 | typedef tile< 8, 8, half2> tile_B; |
19 | 59 | typedef tile<16, 8, half2> tile_B_16; |
@@ -43,37 +83,37 @@ struct fattn_mma_f16_config; |
43 | 83 | // Perhaps the 256 head size needs a closer look |
44 | 84 | // to see if this implementation is better. |
45 | 85 | // |
46 | | -template <> |
47 | | -struct fattn_mma_f16_config< 64, 64> { |
48 | | - static constexpr int nbatch_fa = 64; |
49 | | - static constexpr int nwarps_max = 4; |
50 | | - static constexpr bool Q_in_reg = true; |
51 | | - static constexpr int nstages_target = 2; |
52 | | - |
53 | | - static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
54 | | - return 32; |
55 | | - } |
56 | | - |
57 | | - static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
58 | | - return 32; |
59 | | - } |
60 | | - |
61 | | - static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
62 | | - return 32; |
63 | | - } |
64 | | - |
65 | | - static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
66 | | - return 32; |
67 | | - } |
68 | | - |
69 | | - static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
70 | | - return 32; |
71 | | - } |
72 | | - |
73 | | - static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
74 | | - return 32; |
75 | | - } |
76 | | -}; |
| 86 | +//template <> |
| 87 | +//struct fattn_mma_f16_config< 64, 64> { |
| 88 | +// static constexpr int nbatch_fa = 64; |
| 89 | +// static constexpr int nwarps_max = 4; |
| 90 | +// static constexpr bool Q_in_reg = true; |
| 91 | +// static constexpr int nstages_target = 2; |
| 92 | +// |
| 93 | +// static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 94 | +// return 32; |
| 95 | +// } |
| 96 | +// |
| 97 | +// static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 98 | +// return 32; |
| 99 | +// } |
| 100 | +// |
| 101 | +// static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 102 | +// return 32; |
| 103 | +// } |
| 104 | +// |
| 105 | +// static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 106 | +// return 32; |
| 107 | +// } |
| 108 | +// |
| 109 | +// static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 110 | +// return 32; |
| 111 | +// } |
| 112 | +// |
| 113 | +// static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 114 | +// return 32; |
| 115 | +// } |
| 116 | +//}; |
77 | 117 | // |
78 | 118 | //template <> |
79 | 119 | //struct fattn_mma_f16_config< 80, 80> { |
@@ -243,6 +283,38 @@ struct fattn_mma_f16_config< 64, 64> { |
243 | 283 | // } |
244 | 284 | //}; |
245 | 285 |
|
| 286 | +template <> |
| 287 | +struct fattn_mma_f16_config<192, 128> { |
| 288 | + static constexpr int nbatch_fa = 64; |
| 289 | + static constexpr int nwarps_max = 4; |
| 290 | + static constexpr bool Q_in_reg = true; |
| 291 | + static constexpr int nstages_target = 1; |
| 292 | + |
| 293 | + static int get_nbatch_K2_host(const int /*cc*/, const int /*ncols*/) { |
| 294 | + return 64; |
| 295 | + } |
| 296 | + |
| 297 | + static constexpr __device__ int get_nbatch_K2_device(int /*ncols*/) { |
| 298 | + return 64; |
| 299 | + } |
| 300 | + |
| 301 | + static int get_nbatch_V2_host(const int /*cc*/, const int /*ncols*/) { |
| 302 | + return 64; |
| 303 | + } |
| 304 | + |
| 305 | + static constexpr __device__ int get_nbatch_V2_device(int /*ncols*/) { |
| 306 | + return 64; |
| 307 | + } |
| 308 | + |
| 309 | + static int get_nbatch_combine_host(const int /*cc*/, const int /*ncols*/) { |
| 310 | + return 64; |
| 311 | + } |
| 312 | + |
| 313 | + static constexpr __device__ int get_nbatch_combine_device(int /*ncols*/) { |
| 314 | + return 64; |
| 315 | + } |
| 316 | +}; |
| 317 | + |
246 | 318 | template <> |
247 | 319 | struct fattn_mma_f16_config<576, 512> { |
248 | 320 | static constexpr int nbatch_fa = 32; |
@@ -1287,6 +1359,7 @@ static __global__ void flash_attn_ext_f16( |
1287 | 1359 | const char * __restrict__ V, |
1288 | 1360 | const char * __restrict__ mask, |
1289 | 1361 | const char * __restrict__ sinks, |
| 1362 | + const int * __restrict__ KV_max, |
1290 | 1363 | float * __restrict__ dst, |
1291 | 1364 | float2 * __restrict__ dst_meta, |
1292 | 1365 | const float scale, |
@@ -1377,8 +1450,11 @@ static __global__ void flash_attn_ext_f16( |
1377 | 1450 |
|
1378 | 1451 | const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; |
1379 | 1452 |
|
1380 | | - const int kb0_start_kernel = kb0_start * kb_niter; |
1381 | | - const int kb0_stop_kernel = kb0_stop * kb_niter; |
| 1453 | + int kb0_start_kernel = kb0_start * kb_niter; |
| 1454 | + int kb0_stop_kernel = kb0_stop * kb_niter; |
| 1455 | + if (KV_max) { |
| 1456 | + kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa); |
| 1457 | + } |
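A worked example of the clamp, using the nbatch_fa = 64 of the 192/128 config above: kb0_stop_kernel = min(kb0_stop * kb_niter, KV_max[jt] / 64), so a Q tile whose mask is entirely -inf beyond KV position 2048 runs at most 2048 / 64 = 32 inner iterations, regardless of how long the padded KV cache is. The same clamp is repeated in the fixup branch further down.
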
1382 | 1458 |
|
1383 | 1459 | constexpr bool is_fixup = false; // All but (potentially) the last iterations write their data to dst rather than the fixup buffer. |
1384 | 1460 | if (kb0_start == 0) { |
@@ -1417,8 +1493,11 @@ static __global__ void flash_attn_ext_f16( |
1417 | 1493 |
|
1418 | 1494 | const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, channel, n_head_log2, m0, m1) : 1.0f; |
1419 | 1495 |
|
1420 | | - const int kb0_start_kernel = kb0_start * kb_niter; |
1421 | | - const int kb0_stop_kernel = kb0_stop * kb_niter; |
| 1496 | + int kb0_start_kernel = kb0_start * kb_niter; |
| 1497 | + int kb0_stop_kernel = kb0_stop * kb_niter; |
| 1498 | + if (KV_max) { |
| 1499 | + kb0_stop_kernel = min(kb0_stop_kernel, KV_max[jt] / c::nbatch_fa); |
| 1500 | + } |
1422 | 1501 |
|
1423 | 1502 | constexpr bool is_fixup = true; // Last index writes its data to fixup buffer to avoid data races with other blocks. |
1424 | 1503 | constexpr bool needs_fixup = false; |
@@ -1574,9 +1653,68 @@ static __global__ void flash_attn_combine_results_new( |
1574 | 1653 | dst[blockIdx.z*D + tid] = VKQ_numerator / VKQ_denominator; |
1575 | 1654 | } |
1576 | 1655 |
|
| 1656 | +template<int width = WARP_SIZE> |
| 1657 | +static __device__ __forceinline__ int warp_reduce_all(int x) { |
| 1658 | + if constexpr (width == WARP_SIZE) { //ggml_cuda_get_physical_warp_size()) { |
| 1659 | + return __all_sync(0xffffffff, x); |
| 1660 | + } else { |
| 1661 | +#pragma unroll |
| 1662 | + for (int offset = width/2; offset > 0; offset >>= 1) { |
| 1663 | + x = __shfl_xor_sync(0xffffffff, x, offset, width) && x; |
| 1664 | + } |
| 1665 | + return x; |
| 1666 | + } |
| 1667 | +} |
| 1668 | + |
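warp_reduce_all uses the hardware vote __all_sync for a full warp and falls back to a butterfly of XOR shuffles for narrower widths, ANDing each lane with its partner at every step. A host-side reference of that butterfly, illustrative only (the helper name is made up), showing that every lane in a width-sized group ends up with the AND of the whole group:

    // Simulate the sub-warp path of warp_reduce_all on the CPU: x[] holds one value per lane.
    static void butterfly_all_ref(int * x, const int width) { // width: power of two, <= 32
        for (int offset = width/2; offset > 0; offset >>= 1) {
            int tmp[32];
            for (int lane = 0; lane < width; ++lane) {
                tmp[lane] = x[lane ^ offset] && x[lane]; // mirrors __shfl_xor_sync(...) && x
            }
            for (int lane = 0; lane < width; ++lane) {
                x[lane] = tmp[lane];
            }
        }
        // With width == 8 and a single 0 anywhere in x, all eight lanes now hold 0.
    }
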
| 1669 | +template <int ncols1> |
| 1670 | +__launch_bounds__(FATTN_KQ_STRIDE/2, 1) |
| 1671 | +static __global__ void flash_attn_mask_to_KV_max( |
| 1672 | + const half2 * __restrict__ mask, int * __restrict__ KV_min_max, const int ne30, const int s31, const int s33) { |
| 1673 | + const int ne31 = gridDim.x; |
| 1674 | + const int tid = threadIdx.x; |
| 1675 | + const int sequence = blockIdx.y; |
| 1676 | + const int jt = blockIdx.x; |
| 1677 | + |
| 1678 | + mask += sequence*s33 + jt*ncols1*s31; |
| 1679 | + |
| 1680 | + __shared__ int buf_iw[WARP_SIZE]; |
| 1681 | + if (tid < WARP_SIZE) { |
| 1682 | + buf_iw[tid] = 1; |
| 1683 | + } |
| 1684 | + __syncthreads(); |
| 1685 | + |
| 1686 | + int KV_max_sj = (ne30 - 1) * FATTN_KQ_STRIDE; |
| 1687 | + for (; KV_max_sj >= 0; KV_max_sj -= FATTN_KQ_STRIDE) { |
| 1688 | + int all_inf = 1; |
| 1689 | + |
| 1690 | +#pragma unroll |
| 1691 | + for (int j = 0; j < ncols1; ++j) { |
| 1692 | + const float2 tmp = __half22float2(mask[j*s31 + KV_max_sj/2 + tid]); |
| 1693 | + all_inf = all_inf && int(isinf(tmp.x)) && int(isinf(tmp.y)); |
| 1694 | + } |
| 1695 | + |
| 1696 | + all_inf = warp_reduce_all(all_inf); |
| 1697 | + if (tid % WARP_SIZE == 0) { |
| 1698 | + buf_iw[tid / WARP_SIZE] = all_inf; |
| 1699 | + } |
| 1700 | + __syncthreads(); |
| 1701 | + all_inf = buf_iw[tid % WARP_SIZE]; |
| 1702 | + __syncthreads(); |
| 1703 | + all_inf = warp_reduce_all(all_inf); |
| 1704 | + |
| 1705 | + if (!all_inf) { |
| 1706 | + break; |
| 1707 | + } |
| 1708 | + } |
| 1709 | + |
| 1710 | + if (threadIdx.x == 0) { |
| 1711 | + KV_min_max[sequence*ne31 + jt] = KV_max_sj + FATTN_KQ_STRIDE; |
| 1712 | + } |
| 1713 | +} |
| 1714 | + |
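flash_attn_mask_to_KV_max walks the mask of one (sequence, Q tile) pair backwards in FATTN_KQ_STRIDE-wide blocks and stops at the first block that is not -inf everywhere across the ncols1 rows, storing an exclusive upper bound on the KV positions that still matter. A scalar reference of the same computation, written against a plain float mask for clarity (kv_max_ref is a hypothetical helper, not part of the file):

    #include <cmath>

    // Smallest multiple of `stride` (FATTN_KQ_STRIDE) covering every KV position whose mask
    // entry is finite in any of the ncols1 rows of this Q tile; 0 if the tile is fully masked.
    static int kv_max_ref(const float * mask, const int ncols1, const int ldm, const int ne_kv, const int stride) {
        for (int kv_hi = ne_kv - stride; kv_hi >= 0; kv_hi -= stride) {
            bool all_inf = true;
            for (int j = 0; j < ncols1 && all_inf; ++j) {
                for (int k = 0; k < stride && all_inf; ++k) {
                    all_inf = std::isinf(mask[j*ldm + kv_hi + k]);
                }
            }
            if (!all_inf) {
                return kv_hi + stride; // first block (from the back) that still matters
            }
        }
        return 0;
    }
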
1577 | 1715 | template <int DV, int ncols1, int ncols2> |
1578 | 1716 | static void launch_fattn_new_mma( |
1579 | | - ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, |
| 1717 | + ggml_backend_cuda_context & ctx, ggml_tensor * dst, fattn_new_mma_kernel_t fattn_kernel, const int nwarps, const size_t nbytes_shared, |
1580 | 1718 | const int KQ_row_granularity, const bool need_f16_K, const bool need_f16_V, const bool stream_k, const int warp_size = WARP_SIZE |
1581 | 1719 | ) { |
1582 | 1720 | constexpr int ncols = ncols1 * ncols2; |
@@ -1605,10 +1743,15 @@ static void launch_fattn_new_mma( |
1605 | 1743 | cudaStream_t main_stream = ctx.stream(); |
1606 | 1744 | const int id = ggml_cuda_get_device(); |
1607 | 1745 | const int cc = ggml_cuda_info().devices[id].cc; |
1608 | | - const int nsm = ggml_cuda_info().devices[id].nsm; |
| 1746 | + const int nsm_actual = ggml_cuda_info().devices[id].nsm; |
| 1747 | + int nsm = 1; while (nsm*2 <= nsm_actual) nsm *= 2; |
| 1748 | + |
| 1749 | + if (Q->ne[1] == 1 && K->ne[1] <= 4096 && nsm > 32) nsm /= 2; |
| 1750 | + if (Q->ne[1] >= 32 && K->ne[1] >= 4096) nsm *= 2; |
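In other words, the SM count used for scheduling is first rounded down to a power of two and then nudged by two heuristics. With an illustrative nsm_actual of 84: nsm becomes 64; single-token decode (Q->ne[1] == 1) against a cache of at most 4096 tokens halves it to 32, while a batch of 32 or more query tokens against a 4096+ cache doubles it to 128.
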
1609 | 1751 |
|
1610 | 1752 | ggml_cuda_pool_alloc<half> K_f16(pool); |
1611 | 1753 | ggml_cuda_pool_alloc<half> V_f16(pool); |
| 1754 | + ggml_cuda_pool_alloc<int> KV_max(pool); |
1612 | 1755 | ggml_cuda_pool_alloc<float> dst_tmp(pool); |
1613 | 1756 | ggml_cuda_pool_alloc<float2> dst_tmp_meta(pool); |
1614 | 1757 |
|
@@ -1675,6 +1818,25 @@ static void launch_fattn_new_mma( |
1675 | 1818 | const int ntiles_x = ((Q->ne[1] + ncols1 - 1) / ncols1); |
1676 | 1819 | const int ntiles_total = ntiles_x * (Q->ne[2] / ncols2) * Q->ne[3]; |
1677 | 1820 |
|
| 1821 | + // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. |
| 1822 | + // Only worth the overhead if there is at least one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or |
| 1823 | + // there are multiple sequences of possibly different lengths. |

| 1824 | + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { |
| 1825 | + const int s31 = mask->nb[1] / sizeof(half2); |
| 1826 | + const int s33 = mask->nb[3] / sizeof(half2); |
| 1827 | + |
| 1828 | + const dim3 blocks_num_KV_max(ntiles_x, Q->ne[3], 1); |
| 1829 | + const dim3 block_dim_KV_max(FATTN_KQ_STRIDE/2, 1, 1); |
| 1830 | + |
| 1831 | + const int ne_KV_max = blocks_num_KV_max.x*blocks_num_KV_max.y; |
| 1832 | + const int iter_k = K->ne[1] / FATTN_KQ_STRIDE; |
| 1833 | + |
| 1834 | + KV_max.alloc(ne_KV_max); |
| 1835 | + flash_attn_mask_to_KV_max<ncols1><<<blocks_num_KV_max, block_dim_KV_max, 0, main_stream>>> |
| 1836 | + ((const half2 *) mask->data, KV_max.ptr, iter_k, s31, s33); |
| 1837 | + CUDA_CHECK(cudaGetLastError()); |
| 1838 | + } |
| 1839 | + |
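To make the launch dimensions concrete: with ncols1 = 64, Q->ne[1] = 2048 and two sequences, the scan runs a (32, 2, 1) grid of FATTN_KQ_STRIDE/2 threads per block (the mask is read as half2) and fills 32 * 2 = 64 entries of KV_max, one upper bound per (Q tile, sequence) pair.
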
1678 | 1840 | const dim3 block_dim(warp_size, nwarps, 1); |
1679 | 1841 | int max_blocks_per_sm = 1; // Max. number of active blocks limited by occupancy. |
1680 | 1842 | CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_blocks_per_sm, fattn_kernel, block_dim.x * block_dim.y * block_dim.z, nbytes_shared)); |
@@ -1761,6 +1923,7 @@ static void launch_fattn_new_mma( |
1761 | 1923 | V_data, |
1762 | 1924 | mask ? ((const char *) mask->data) : nullptr, |
1763 | 1925 | sinks ? ((const char *)sinks->data) : nullptr, |
| 1926 | + KV_max.get(), |
1764 | 1927 | !stream_k && parallel_blocks > 1 ? dst_tmp.ptr : (float *) KQV->data, dst_tmp_meta.ptr, |
1765 | 1928 | scale, max_bias, m0, m1, logit_softcap, n_head_log2, |
1766 | 1929 | Q->ne[0], Q->ne[1], Q->ne[2], Q->ne[3], |
@@ -1837,7 +2000,7 @@ static void ggml_cuda_flash_attn_ext_mma_f16_case(ggml_backend_cuda_context & ct |
1837 | 2000 | float logit_softcap; |
1838 | 2001 | memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); |
1839 | 2002 |
|
1840 | | - fattn_kernel_t fattn_kernel; |
| 2003 | + fattn_new_mma_kernel_t fattn_kernel; |
1841 | 2004 | if (logit_softcap == 0.0f) { |
1842 | 2005 | constexpr bool use_logit_softcap = false; |
1843 | 2006 | fattn_kernel = flash_attn_ext_f16<DKQ, DV, ncols1, ncols2, nwarps, ntiles, use_logit_softcap, mla>; |
@@ -1944,8 +2107,16 @@ void ggml_cuda_flash_attn_ext_mma_new(ggml_backend_cuda_context & ctx, ggml_tens |
1944 | 2107 | GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); |
1945 | 2108 | const int gqa_ratio = Q->ne[2] / K->ne[2]; |
1946 | 2109 |
|
1947 | | - if (K->ne[0] == 64 && V->ne[0] == 64) { |
1948 | | - ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst); |
| 2110 | + //if (K->ne[0] == 64 && V->ne[0] == 64) { |
| 2111 | + // ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<64, 64>(ctx, dst); |
| 2112 | + // return; |
| 2113 | + //} |
| 2114 | + if (K->ne[0] == 192 && V->ne[0] == 128) { |
| 2115 | + GGML_ASSERT(Q->ne[0] == 192); |
| 2116 | + GGML_ASSERT(gqa_ratio == 1); |
| 2117 | + //ggml_cuda_flash_attn_ext_mma_f16_switch_ncols2<192, 128>(ctx, dst); |
| 2118 | + // Reduce compile time |
| 2119 | + ggml_cuda_flash_attn_ext_mma_f16_switch_ncols1<192, 128, 1>(ctx, dst); |
1949 | 2120 | return; |
1950 | 2121 | } |
1951 | 2122 | GGML_ASSERT(Q->ne[0] == 576 && K->ne[0] == 576 && V->ne[0] == 512); |
|