diff --git a/custom_ops/gpu_ops/noaux_tc.cu b/custom_ops/gpu_ops/noaux_tc.cu
index 19a9e380f8..302d85eb2f 100644
--- a/custom_ops/gpu_ops/noaux_tc.cu
+++ b/custom_ops/gpu_ops/noaux_tc.cu
@@ -48,6 +48,7 @@ std::vector<paddle::Tensor> NoauxTc(paddle::Tensor& scores,
                                 n_group,
                                 topk_group,
                                 topk,
+                                true,
                                 routed_scaling_factor,
                                 stream);
 
diff --git a/custom_ops/gpu_ops/noauxtc_kernel.h b/custom_ops/gpu_ops/noauxtc_kernel.h
index e8a3f45080..392dbfe3b1 100644
--- a/custom_ops/gpu_ops/noauxtc_kernel.h
+++ b/custom_ops/gpu_ops/noauxtc_kernel.h
@@ -25,6 +25,23 @@ constexpr unsigned FULL_WARP_MASK = 0xffffffff;
 constexpr int32_t BLOCK_SIZE = 512;
 constexpr int32_t NUM_WARPS_PER_BLOCK = BLOCK_SIZE / WARP_SIZE;
 
+template <typename T_OUT, typename T_IN>
+__device__ inline T_OUT cuda_cast(T_IN val) {
+  return val;
+}
+
+template <>
+__device__ inline float cuda_cast<float, __nv_bfloat16>(__nv_bfloat16 val) {
+  return __bfloat162float(val);
+}
+
+template <typename T>
+__device__ inline T neg_inf() {
+  // cuda::std::numeric_limits<T>::infinity() returns `0` for [T=bf16 or fp16]
+  // so we need to cast from fp32
+  return cuda_cast<T, float>(-cuda::std::numeric_limits<float>::infinity());
+}
+
 namespace warp_topk {
 
 template <int size, typename T>
@@ -41,10 +58,21 @@ constexpr __host__ __device__ bool isPowerOf2(T v) {
 }
 
 template <typename T, bool greater>
-__device__ bool is_better_than(T val, T baseline) {
+__forceinline__ __device__ bool is_better_than(T val, T baseline) {
   return (val > baseline && greater) || (val < baseline && !greater);
 }
 
+template <typename T, bool greater, typename idxT>
+__forceinline__ __device__ bool is_better_than(T val, T baseline, idxT index,
+                                               idxT baseline_index) {
+  bool res = (val > baseline && greater) || (val < baseline && !greater);
+  if (val == baseline) {
+    res = (index < baseline_index && greater) ||
+          (index < baseline_index && !greater);
+  }
+  return res;
+}
+
 template <typename T, typename idxT>
 int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
   int64_t cache_topk = (sizeof(T) + sizeof(idxT)) * num_of_warp * k;
@@ -53,7 +81,8 @@ int calc_smem_size_for_block_wide(int num_of_warp, int64_t k) {
              round_up_to_multiple_of<256>(n * sizeof(T)) + n * sizeof(idxT));
 }
 
-template <int size, bool ascending, typename T, typename idxT>
+template <int size, bool ascending, bool reverse, typename T, typename idxT,
+          bool is_stable>
 struct BitonicMerge {
   // input should be a bitonic sequence, and sort it to be a monotonic sequence
   __device__ static void merge(T* __restrict__ val_arr,
@@ -67,7 +96,15 @@ struct BitonicMerge {
       int const other_i = i + stride;
       T& val = val_arr[i];
       T& other_val = val_arr[other_i];
-      if ((val > other_val && ascending) || (val < other_val && !ascending)) {
+      bool is_better;
+      if constexpr (is_stable) {
+        is_better = is_better_than<T, ascending>(val, other_val, idx_arr[i],
+                                                 idx_arr[other_i]);
+      } else {
+        is_better = is_better_than<T, ascending>(val, other_val);
+      }
+
+      if (is_better) {
         T tmp = val;
         val = other_val;
         other_val = tmp;
@@ -78,13 +115,14 @@
       }
     }
 
-    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr, idx_arr);
-    BitonicMerge<size / 2, ascending, T, idxT>::merge(val_arr + arr_len / 2,
-                                                      idx_arr + arr_len / 2);
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
+    BitonicMerge<size / 2, ascending, reverse, T, idxT, is_stable>::merge(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
   }
 };
 
-template <int size, bool ascending, typename T, typename idxT>
+template <int size, bool ascending, typename T, typename idxT, bool is_stable>
 struct BitonicSort {
   __device__ static void sort(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
@@ -92,15 +130,16 @@
     static_assert(size >= 2 * WARP_SIZE);
     constexpr int arr_len = size / WARP_SIZE;
 
-    BitonicSort<size / 2, ascending, T, idxT>::sort(val_arr, idx_arr);
-    BitonicSort<size / 2, ascending, T, idxT>::sort(val_arr + arr_len / 2,
-                                                    idx_arr + arr_len / 2);
-    BitonicMerge<size, ascending, T, idxT>::merge(val_arr, idx_arr);
+    BitonicSort<size / 2, ascending, T, idxT, is_stable>::sort(val_arr, idx_arr);
+    BitonicSort<size / 2, ascending, T, idxT, is_stable>::sort(
+        val_arr + arr_len / 2, idx_arr + arr_len / 2);
+    BitonicMerge<size, ascending, ascending, T, idxT, is_stable>::merge(
+        val_arr, idx_arr);
   }
 };
 
-template <bool ascending, typename T, typename idxT>
-struct BitonicSort<32, ascending, T, idxT> {
+template <bool ascending, typename T, typename idxT, bool is_stable>
+struct BitonicSort<32, ascending, T, idxT, is_stable> {
   __device__ static void sort(T* __restrict__ val_arr,
                               idxT* __restrict__ idx_arr) {
     int const lane = threadIdx.x % WARP_SIZE;
@@ -114,19 +153,37 @@ struct BitonicSort<32, ascending, T, idxT> {
         T other = __shfl_xor_sync(FULL_WARP_MASK, *val_arr, stride);
         idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, *idx_arr, stride);
 
-        if (*val_arr != other && (*val_arr > other) != (reverse != is_second)) {
+
+        bool is_better;
+        if constexpr (is_stable) {
+          if constexpr (ascending) {
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr < other_idx))) !=
+                        (reverse != is_second);
+          } else {
+            is_better = ((*val_arr > other) ||
+                         ((*val_arr == other) && (*idx_arr > other_idx))) !=
+                        (reverse != is_second);
+          }
+        } else {
+          is_better = (*val_arr != other &&
+                       (*val_arr > other) != (reverse != is_second));
+        }
+        if (is_better) {
           *val_arr = other;
           *idx_arr = other_idx;
         }
       }
     }
 
-    BitonicMerge<32, ascending, T, idxT>::merge(val_arr, idx_arr);
+    BitonicMerge<32, ascending, ascending, T, idxT, is_stable>::merge(val_arr,
+                                                                      idx_arr);
   }
 };
 
-template <bool ascending, typename T, typename idxT>
-struct BitonicMerge<32, ascending, T, idxT> {
+template <bool ascending, bool reverse, typename T, typename idxT,
+          bool is_stable>
+struct BitonicMerge<32, ascending, reverse, T, idxT, is_stable> {
   __device__ static void merge(T* __restrict__ val_arr,
                                idxT* __restrict__ idx_arr) {
     int const lane = threadIdx.x % WARP_SIZE;
@@ -136,7 +193,24 @@ struct BitonicMerge<32, ascending, T, idxT> {
       T other = __shfl_xor_sync(FULL_WARP_MASK, val, stride);
       idxT& idx = *idx_arr;
       idxT other_idx = __shfl_xor_sync(FULL_WARP_MASK, idx, stride);
-      if (val != other && ((val > other) == (ascending != is_second))) {
+
+      bool is_better;
+      if constexpr (is_stable) {
+        if constexpr (ascending) {
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr < other_idx))) ==
+                      (reverse != is_second);  // for min
+        } else {
+          is_better = ((*val_arr > other) ||
+                       ((*val_arr == other) && (*idx_arr > other_idx))) ==
+                      (reverse != is_second);  // for max
+        }
+      } else {
+        is_better =
+            (val != other && ((val > other) == (ascending != is_second)));
+      }
+
+      if (is_better) {
         val = other;
         idx = other_idx;
       }
@@ -144,34 +218,42 @@
   }
 };
 
-template <int capacity, bool greater, typename T, typename idxT>
+template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
 class WarpSort {
-public:
+ public:
   __device__ WarpSort(idxT k, T dummy)
       : lane_(threadIdx.x % WARP_SIZE), k_(k), dummy_(dummy) {
     static_assert(capacity >= WARP_SIZE && isPowerOf2(capacity));
 
    for (int i = 0; i < max_arr_len_; ++i) {
      val_arr_[i] = dummy_;
+      idx_arr_[i] = 0;
    }
  }
 
   // load and merge k sorted values
   __device__ void load_sorted(T const* __restrict__ in,
-                              idxT const* __restrict__ in_idx,
-                              idxT start) {
+                              idxT const* __restrict__ in_idx, idxT start) {
     idxT idx = start + WARP_SIZE - 1 - lane_;
     for (int i = max_arr_len_ - 1; i >= 0; --i, idx += WARP_SIZE) {
       if (idx < start + k_) {
         T t = in[idx];
-        if (is_better_than<T, greater>(t, val_arr_[i])) {
+        bool is_better;
+        if constexpr (is_stable) {
+          is_better =
+              is_better_than<T, greater>(t, val_arr_[i], in_idx[idx], idx_arr_[i]);
+        } else {
+          is_better = is_better_than<T, greater>(t, val_arr_[i]);
+        }
+        if (is_better) {
           val_arr_[i] = t;
           idx_arr_[i] = in_idx[idx];
         }
       }
     }
 
-    BitonicMerge<capacity, greater, T, idxT>::merge(val_arr_, idx_arr_);
+    BitonicMerge<capacity, greater, greater, T, idxT, is_stable>::merge(
+        val_arr_, idx_arr_);
   }
 
   __device__ void dump(T* __restrict__ out, idxT* __restrict__ out_idx) const {
@@ -193,7 +275,7 @@
     }
   }
 
-protected:
+ protected:
   static constexpr int max_arr_len_ = capacity / WARP_SIZE;
 
   T val_arr_[max_arr_len_];
@@ -205,11 +287,11 @@ class WarpSort {
 };  // end class WarpSort
 
-template <int capacity, bool greater, typename T, typename idxT>
-class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
-public:
+template <int capacity, bool greater, typename T, typename idxT, bool is_stable>
+class WarpSelect : public WarpSort<capacity, greater, T, idxT, is_stable> {
+ public:
   __device__ WarpSelect(idxT k, T dummy)
-      : WarpSort<capacity, greater, T, idxT>(k, dummy),
+      : WarpSort<capacity, greater, T, idxT, is_stable>(k, dummy),
         k_th_(dummy),
         k_th_lane_((k - 1) % WARP_SIZE) {
     extern __shared__ char smem_buf[];  // extern __shared__ T smem_buf[];
@@ -234,7 +316,13 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
   }
 
   __device__ void add(T val, idxT idx) {
-    bool do_add = is_better_than<T, greater>(val, k_th_);
+    bool do_add;
+    if constexpr (is_stable) {
+      do_add = is_better_than<T, greater>(val, k_th_, idx, k_th_idx_);
+    } else {
+      do_add = is_better_than<T, greater>(val, k_th_);
+    }
+
     uint32_t mask = __ballot_sync(FULL_WARP_MASK, do_add);
     if (mask == 0) {
       return;
@@ -271,37 +359,52 @@ class WarpSelect : public WarpSort<capacity, greater, T, idxT> {
     __syncthreads();
   }
 
-private:
+ private:
   __device__ void set_k_th_() {
     k_th_ = __shfl_sync(FULL_WARP_MASK, val_arr_[max_arr_len_ - 1], k_th_lane_);
+    if constexpr (is_stable) {
+      k_th_idx_ =
+          __shfl_sync(FULL_WARP_MASK, idx_arr_[max_arr_len_ - 1], k_th_lane_);
+    }
   }
 
   __device__ void merge_buf_(T val, idxT idx) {
-    BitonicSort<32, greater, T, idxT>::sort(&val, &idx);
+    BitonicSort<32, greater, T, idxT, is_stable>::sort(&val, &idx);
 
     T& old = val_arr_[max_arr_len_ - 1];
-    if (is_better_than<T, greater>(val, old)) {
+
+    bool is_better;
+    if constexpr (is_stable) {
+      is_better =
+          is_better_than<T, greater>(val, old, idx, idx_arr_[max_arr_len_ - 1]);
+    } else {
+      is_better = is_better_than<T, greater>(val, old);
+    }
+
+    if (is_better) {
       old = val;
      idx_arr_[max_arr_len_ - 1] = idx;
    }
 
-    BitonicMerge<capacity, greater, T, idxT>::merge(val_arr_, idx_arr_);
+    BitonicMerge<capacity, greater, greater, T, idxT, is_stable>::merge(
+        val_arr_, idx_arr_);
 
     set_k_th_();
   }
 
-  using WarpSort<capacity, greater, T, idxT>::max_arr_len_;
-  using WarpSort<capacity, greater, T, idxT>::val_arr_;
-  using WarpSort<capacity, greater, T, idxT>::idx_arr_;
-  using WarpSort<capacity, greater, T, idxT>::lane_;
-  using WarpSort<capacity, greater, T, idxT>::k_;
-  using WarpSort<capacity, greater, T, idxT>::dummy_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::max_arr_len_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::val_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::idx_arr_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::lane_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::k_;
+  using WarpSort<capacity, greater, T, idxT, is_stable>::dummy_;
 
   T* val_smem_;
   idxT* idx_smem_;
   int smem_buf_len_ = 0;
 
   T k_th_;
+  idxT k_th_idx_;
   int const k_th_lane_;
 };  // end class WarpSelect
 }  // namespace warp_topk
@@ -313,8 +416,8 @@ __device__ void topk_with_k2(T* output,
                              int32_t const lane_id,
                              int const num_experts_per_group) {
   // Get the top2 per thread
-  T largest = cuda::std::numeric_limits<T>::min();
-  T second_largest = cuda::std::numeric_limits<T>::min();
+  T largest = neg_inf<T>();
+  T second_largest = neg_inf<T>();
 
   if (num_experts_per_group > WARP_SIZE) {
     for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
@@ -368,8 +471,14 @@ __global__ void topk_with_k2_kernel(T* output,
     cg::thread_block block = cg::this_thread_block();
     cg::thread_block_tile<32> tile = cg::tiled_partition<32>(block);
 
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+    asm volatile("griddepcontrol.wait;");
+#endif
     topk_with_k2(output, input, tile, lane_id, num_experts_per_group);
   }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 
 template <typename T, typename IdxT>
@@ -385,6 +494,7 @@ __global__ void group_idx_and_topk_idx_kernel(
     int64_t const topk,
     int64_t const num_experts,
     int64_t const num_experts_per_group,
+    bool const renormalize,
    double routed_scaling_factor) {
  int32_t warp_id = threadIdx.x / WARP_SIZE;
  int32_t lane_id = threadIdx.x % WARP_SIZE;
@@ -403,19 +513,29 @@ __global__ void group_idx_and_topk_idx_kernel(
   extern __shared__ char smem_buf[];  // NOTE: reuse the shared memory here to
                                       // store the target topk idx
-  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf) + warp_id * topk;
+  int32_t* s_topk_idx = reinterpret_cast<int32_t*>(smem_buf);
   T* s_topk_value =
       reinterpret_cast<T*>(s_topk_idx + NUM_WARPS_PER_BLOCK * topk) +
       warp_id * topk;
+  s_topk_idx += warp_id * topk;
 
-  T value = cuda::std::numeric_limits<T>::min();
-  T topk_group_value = cuda::std::numeric_limits<T>::min();
+  T value = neg_inf<T>();
+  T topk_group_value = neg_inf<T>();
   int32_t num_equalto_topkth_group;
 
-  if ((n_group > topk_group) && (case_id < num_tokens)) {
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.wait;");  // I think all prolog can be put before
+                                         // acqbulk because it's ptr arithmetic
+#endif
+
+  if (case_id < num_tokens) {
     // calculate group_idx
     int32_t target_num_min = WARP_SIZE - n_group + topk_group;
-    if (lane_id < n_group) {
+    if (lane_id < n_group &&
+        (isfinite(cuda_cast<float, T>(
+            group_scores[lane_id]))))  // The check is necessary to avoid
+                                       // abnormal input
+    {
       value = group_scores[lane_id];
     }
 
@@ -426,22 +546,23 @@ __global__ void group_idx_and_topk_idx_kernel(
       __syncwarp();  // Ensure all threads have valid data before reduction
       topk_group_value = cg::reduce(tile, value, cg::greater<T>());
       if (value == topk_group_value) {
-        value = cuda::std::numeric_limits<T>::min();
+        value = neg_inf<T>();
       }
       pre_count_equal_to_top_value = count_equal_to_top_value;
       count_equal_to_top_value = __popc(__ballot_sync(
-          FULL_WARP_MASK, (value == cuda::std::numeric_limits<T>::min())));
+          FULL_WARP_MASK, (value == neg_inf<T>())));
     }
     num_equalto_topkth_group = target_num_min - pre_count_equal_to_top_value;
   }
   __syncthreads();
 
-  warp_topk::WarpSelect</*capability=*/WARP_SIZE, /*greater=*/true, T, int32_t>
-      queue((int32_t)topk, cuda::std::numeric_limits<T>::min());
+  warp_topk::WarpSelect</*capability=*/WARP_SIZE, /*greater=*/true, T, int32_t,
+                        /*is_stable=*/true>
+      queue((int32_t)topk, neg_inf<T>());
 
   int count_equalto_topkth_group = 0;
-  bool if_proceed_next_topk = (topk_group_value != cuda::std::numeric_limits<T>::min());
-  if (case_id < num_tokens) {
+  bool if_proceed_next_topk = (topk_group_value != neg_inf<T>());
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i_group = 0; i_group < n_group; i_group++) {
       if ((group_scores[i_group] > topk_group_value) ||
           ((group_scores[i_group] == topk_group_value) &&
@@ -449,9 +570,11 @@ __global__ void group_idx_and_topk_idx_kernel(
         int32_t offset = i_group * num_experts_per_group;
         for (int32_t i = lane_id; i < align_num_experts_per_group;
              i += WARP_SIZE) {
-          T candidates = i < num_experts_per_group
-                             ? scores_with_bias[offset + i]
-                             : cuda::std::numeric_limits<T>::min();
+          T candidates =
+              (i < num_experts_per_group) && isfinite(cuda_cast<float, T>(
+                                                 scores_with_bias[offset + i]))
+                  ? scores_with_bias[offset + i]
+                  : neg_inf<T>();
           queue.add(candidates, offset + i);
         }
         if (group_scores[i_group] == topk_group_value) {
@@ -469,7 +592,7 @@ __global__ void group_idx_and_topk_idx_kernel(
   // Load the valid score value
   // Calculate the summation
   float topk_sum = 1e-20;
-  if (case_id < num_tokens) {
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i = lane_id;
          i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
         i += WARP_SIZE) {
@@ -478,33 +601,45 @@ __global__ void group_idx_and_topk_idx_kernel(
       if (i < topk) {
         s_topk_value[i] = value;
       }
-      topk_sum += reduce(tile, value, cg::plus<float>());
+      topk_sum += reduce(tile, cuda_cast<float, T>(value), cg::plus<float>());
     }
   }
 
   __syncthreads();
-  if (case_id < num_tokens) {
+
+  if (case_id < num_tokens && if_proceed_next_topk) {
     for (int i = lane_id; i < num_experts; i += WARP_SIZE) {
       scores[i] = 0;
     }
   }
-  __threadfence();
-  __syncthreads();
+  __syncwarp();
 
   if (case_id < num_tokens) {
-    for (int i = lane_id; i < topk; i += WARP_SIZE) {
-      float value = s_topk_value[i] / topk_sum * routed_scaling_factor;
-      scores[s_topk_idx[i]] = value;
-      if (if_proceed_next_topk) {
+    if (if_proceed_next_topk) {
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
+        float value;
+        if (renormalize) {
+          value = cuda_cast<float, T>(s_topk_value[i]) / topk_sum *
+                  routed_scaling_factor;
+        } else {
+          value = cuda_cast<float, T>(s_topk_value[i]) * routed_scaling_factor;
+        }
+        scores[s_topk_idx[i]] = value;
         topk_indices[i] = s_topk_idx[i];
-        topk_values[i] = static_cast<T>(value);
+        topk_values[i] = cuda_cast<T, float>(value);
       }
-      else {
+    } else {
+      for (int i = lane_id; i < topk; i += WARP_SIZE) {
        topk_indices[i] = i;
-        topk_values[i] = static_cast<T>(1.0f / topk);
+        topk_values[i] = cuda_cast<T, float>(1.0f / topk);
      }
    }
+    // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
+    // default result.
  }
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
+  asm volatile("griddepcontrol.launch_dependents;");
+#endif
 }
 
 template <typename T, typename IdxT>
@@ -518,17 +653,24 @@ void invokeNoAuxTc(T* scores,
                    int64_t const n_group,
                    int64_t const topk_group,
                    int64_t const topk,
+                   bool const renormalize,
                    double const routed_scaling_factor,
                    cudaStream_t const stream) {
   int64_t num_cases = num_tokens * n_group;
   int64_t topk_with_k2_num_blocks = (num_cases - 1) / NUM_WARPS_PER_BLOCK + 1;
-  topk_with_k2_kernel<<<topk_with_k2_num_blocks, BLOCK_SIZE, 0, stream>>>(
-      group_scores,
-      scores_with_bias,
-      num_tokens,
-      num_cases,
-      n_group,
-      num_experts / n_group);
+  auto* kernel_instance1 = &topk_with_k2_kernel<T>;
+  cudaLaunchConfig_t config;
+  config.gridDim = topk_with_k2_num_blocks;
+  config.blockDim = BLOCK_SIZE;
+  config.dynamicSmemBytes = 0;
+  config.stream = stream;
+  cudaLaunchAttribute attrs[1];
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = false;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(&config, kernel_instance1, group_scores, scores_with_bias,
+                     num_tokens, num_cases, n_group, num_experts / n_group);
 
   int64_t topk_with_k_group_num_blocks =
       (num_tokens - 1) / NUM_WARPS_PER_BLOCK + 1;
@@ -536,21 +678,19 @@ void invokeNoAuxTc(T* scores,
       warp_topk::calc_smem_size_for_block_wide<T, int32_t>(NUM_WARPS_PER_BLOCK,
                                                            topk);
 
-  group_idx_and_topk_idx_kernel<<<topk_with_k_group_num_blocks,
-                                  BLOCK_SIZE,
-                                  dynamic_smem_in_bytes,
-                                  stream>>>(scores,
-                                            group_scores,
-                                            topk_values,
-                                            topk_indices,
-                                            scores_with_bias,
-                                            num_tokens,
-                                            n_group,
-                                            topk_group,
-                                            topk,
-                                            num_experts,
-                                            num_experts / n_group,
-                                            routed_scaling_factor);
+  auto* kernel_instance2 = &group_idx_and_topk_idx_kernel<T, IdxT>;
+  config.gridDim = topk_with_k_group_num_blocks;
+  config.blockDim = BLOCK_SIZE;
+  config.dynamicSmemBytes = dynamic_smem_in_bytes;
+  config.stream = stream;
+  attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
+  attrs[0].val.programmaticStreamSerializationAllowed = false;
+  config.numAttrs = 1;
+  config.attrs = attrs;
+  cudaLaunchKernelEx(&config, kernel_instance2, scores, group_scores,
+                     topk_values, topk_indices, scores_with_bias, num_tokens,
+                     n_group, topk_group, topk, num_experts,
+                     num_experts / n_group, renormalize, routed_scaling_factor);
 }
 
 #define INSTANTIATE_NOAUX_TC(T, IdxT)                                        \
@@ -564,6 +704,7 @@ void invokeNoAuxTc(T* scores,
                                        int64_t const n_group,               \
                                        int64_t const topk_group,            \
                                        int64_t const topk,                  \
+                                       bool const renormalize,              \
                                        double const routed_scaling_factor,  \
                                        cudaStream_t const stream);