|
16 | 16 | #include <math_constants.h> |
17 | 17 | #include <mma.h> |
18 | 18 |
|
| 19 | +#if CCCL_VERSION >= 2008002 |
| 20 | +#include <cuda/std/functional> |
| 21 | +#define CUB_REDUCTIONOP_MAX \ |
| 22 | + cuda::maximum<> {} |
| 23 | +#else |
| 24 | +#define CUB_REDUCTIONOP_MAX cub::Max() |
| 25 | +#endif |
| 26 | + |
19 | 27 | #define HLF_MAX 65504 |
20 | 28 | #define TH 1024 |
21 | 29 | #define NUM 4 |
@@ -365,7 +373,7 @@ __global__ void kQuantizeBlockwise( |
365 | 373 | for (int j = 0; j < NUM_PER_TH; j++) |
366 | 374 | local_abs_max = fmaxf(local_abs_max, fabsf((float)vals[j])); |
367 | 375 |
|
368 | | - local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, cub::Max(), valid_items); |
| 376 | + local_abs_max = BlockReduce(reduce).Reduce(local_abs_max, CUB_REDUCTIONOP_MAX, valid_items); |
369 | 377 |
|
370 | 378 | if (threadIdx.x == 0) { |
371 | 379 | smem_absmax_value[0] = 1.0f / local_abs_max; |
@@ -951,12 +959,12 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b |
951 | 959 | } |
952 | 960 |
|
953 | 961 | __syncthreads(); |
954 | | - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); |
| 962 | + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, CUB_REDUCTIONOP_MAX, valid_items); |
955 | 963 | __syncthreads(); |
956 | | - local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, cub::Max(), valid_items); |
| 964 | + local_max_s2 = BlockReduce(temp_storage.reduce).Reduce(local_max_s2, CUB_REDUCTIONOP_MAX, valid_items); |
957 | 965 | if (unorm != NULL) { |
958 | 966 | __syncthreads(); |
959 | | - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); |
| 967 | + local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items); |
960 | 968 | } |
961 | 969 |
|
962 | 970 | if (threadIdx.x == 0) { |
@@ -1162,13 +1170,13 @@ __global__ void __launch_bounds__(NUM_THREADS, 2) kPreconditionOptimizerStatic8b |
1162 | 1170 | } |
1163 | 1171 |
|
1164 | 1172 | __syncthreads(); |
1165 | | - local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, cub::Max(), valid_items); |
| 1173 | + local_max_s1 = BlockReduce(temp_storage.reduce).Reduce(local_max_s1, CUB_REDUCTIONOP_MAX, valid_items); |
1166 | 1174 | if (threadIdx.x == 0) { |
1167 | 1175 | atomicMax(&new_max1[0], local_max_s1); |
1168 | 1176 | } |
1169 | 1177 | if (unorm != NULL) { |
1170 | 1178 | __syncthreads(); |
1171 | | - local_unorm = BlockReduce(temp_storage.reduce).Reduce(local_unorm, cub::Sum(), valid_items); |
| 1179 | + local_unorm = BlockReduce(temp_storage.reduce).Sum(local_unorm, valid_items); |
1172 | 1180 | if (threadIdx.x == 0) { |
1173 | 1181 | atomicAdd(&unorm[0], local_unorm); |
1174 | 1182 | } |
@@ -1473,11 +1481,11 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit2StateBlockwise( |
1473 | 1481 | } |
1474 | 1482 |
|
1475 | 1483 | // reduce: 2.51/1.60 -> 2.67/1.69 |
1476 | | - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); |
1477 | | - new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, cub::Max()); |
| 1484 | + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX); |
| 1485 | + new_local_abs_max2 = BlockReduce2(reduce2).Reduce(new_local_abs_max2, CUB_REDUCTIONOP_MAX); |
1478 | 1486 |
|
1479 | 1487 | if (OPTIMIZER == ADEMAMIX) { |
1480 | | - new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, cub::Max()); |
| 1488 | + new_local_abs_max3 = BlockReduce3(reduce3).Reduce(new_local_abs_max3, CUB_REDUCTIONOP_MAX); |
1481 | 1489 | } |
1482 | 1490 |
|
1483 | 1491 | if (threadIdx.x == 0) { |
@@ -1686,7 +1694,7 @@ __launch_bounds__(256, 3) __global__ void kOptimizerStatic8bit1StateBlockwise( |
1686 | 1694 | } |
1687 | 1695 |
|
1688 | 1696 | // reduce: 2.51/1.60 -> 2.67/1.69 |
1689 | | - new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, cub::Max()); |
| 1697 | + new_local_abs_max1 = BlockReduce1(reduce1).Reduce(new_local_abs_max1, CUB_REDUCTIONOP_MAX); |
1690 | 1698 |
|
1691 | 1699 | if (threadIdx.x == 0) |
1692 | 1700 | smem_exchange1[0] = new_local_abs_max1; |
@@ -1792,7 +1800,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ |
1792 | 1800 | } |
1793 | 1801 |
|
1794 | 1802 | // Reduce thread-local absmax across the block. |
1795 | | - const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); |
| 1803 | + const TReduction row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols); |
1796 | 1804 | if (threadIdx.x == 0) { |
1797 | 1805 | // Save our block's absmax to shared memory for the quantization step. |
1798 | 1806 | rowStats[row_id] = smem_row_absmax = row_absmax; |
@@ -1847,7 +1855,7 @@ __launch_bounds__(1024, BNB_MAX_THREADS_PER_SM / 1024) __global__ |
1847 | 1855 |
|
1848 | 1856 | // Reduce thread-local absmax across the block. |
1849 | 1857 | // TODO: Consider algorithm BLOCK_REDUCE_RAKING_COMMUTATIVE_ONLY |
1850 | | - const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, cub::Max(), cols); |
| 1858 | + const float row_absmax = BlockReduceT(temp_storage).Reduce(row_local_absmax, CUB_REDUCTIONOP_MAX, cols); |
1851 | 1859 | if (threadIdx.x == 0) { |
1852 | 1860 | // Save our block's absmax to shared memory for the quantization step. |
1853 | 1861 | rowStats[row_id] = row_absmax; |
|
0 commit comments