Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
255 changes: 255 additions & 0 deletions cpp/tensorrt_llm/common/customAllReduceUtils.h

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -134,8 +134,12 @@ public:
// corresponding CTA has not been launched.
for (int flag_idx = blockIdx.x; flag_idx < kBarrierFlagCount; flag_idx += gridDim.x)
{
st_flag(m_target_flag + flag_idx * NRanks, m_flag_value);
asm volatile(
"st.global.relaxed.sys.b32 [%1], %0;" ::"r"(m_flag_value), "l"(m_target_flag + flag_idx * NRanks));
}
// Single release fence
asm volatile("fence.release.sys;");

while (ld_flag(m_current_flag) == prev_flag(m_flag_value))
{
}
Expand Down
24 changes: 24 additions & 0 deletions cpp/tensorrt_llm/kernels/customAllReduceKernels.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,30 @@ inline std::string toString(AllReduceFusionOp op)
return oss.str();
}

// Stream-insertion overload mapping each AllReduceStrategyType enumerator to
// its symbolic name. Unrecognized values write nothing (empty string).
inline std::ostream& operator<<(std::ostream& os, AllReduceStrategyType op)
{
    char const* name = "";
    switch (op)
    {
    case AllReduceStrategyType::NCCL: name = "NCCL"; break;
    case AllReduceStrategyType::MIN_LATENCY: name = "MIN_LATENCY"; break;
    case AllReduceStrategyType::UB: name = "UB"; break;
    case AllReduceStrategyType::AUTO: name = "AUTO"; break;
    case AllReduceStrategyType::ONESHOT: name = "ONESHOT"; break;
    case AllReduceStrategyType::TWOSHOT: name = "TWOSHOT"; break;
    case AllReduceStrategyType::LOWPRECISION: name = "LOWPRECISION"; break;
    case AllReduceStrategyType::MNNVL: name = "MNNVL"; break;
    case AllReduceStrategyType::NCCL_SYMMETRIC: name = "NCCL_SYMMETRIC"; break;
    }
    os << name;
    return os;
}

// Convert an AllReduceStrategyType to its human-readable name by funneling it
// through the stream-insertion overload above.
inline std::string toString(AllReduceStrategyType op)
{
    std::ostringstream stream;
    stream << op;
    return stream.str();
}

struct AllReduceFusionParams
{
AllReduceFusionParams()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -601,6 +601,8 @@ __global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
}
}
__syncthreads();
asm volatile("griddepcontrol.wait;");
asm volatile("griddepcontrol.launch_dependents;");

if (warp_idx < 2)
{
Expand All @@ -622,7 +624,6 @@ __global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
mma_computer.issue_mainloop();
mma_computer.epi();
}
asm volatile("griddepcontrol.launch_dependents;");
#endif
}

Expand Down
60 changes: 30 additions & 30 deletions cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -56,12 +56,12 @@ __global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rw(int const op, int const flagoffset, int const firstrank, int const myrank,
int const gpustep, size_t const lineoffset, int const numlines, void** commbuff, int const handleridx)
{
#if __CUDA_ARCH__ >= 900
cudaTriggerProgrammaticLaunchCompletion();
#endif
__shared__ int4* userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
#if __CUDA_ARCH__ >= 900
cudaGridDependencySynchronize();
#endif
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -72,9 +72,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
#if __CUDA_ARCH__ >= 900
cudaGridDependencySynchronize();
#endif
flagptr[physgpu] = reduce_id;
userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu + handleridx]);
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
Expand Down Expand Up @@ -130,19 +127,22 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
#if __CUDA_ARCH__ >= 900
cudaTriggerProgrammaticLaunchCompletion();
#endif
} // fp16 inplace reduce kernel (Hopper)

template <typename DType, int RANKS>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_inplace_gpu_rr(int const op, int const flagoffset, int const firstrank, int const myrank,
int const gpustep, size_t const lineoffset, int const numlines, void** commbuff, int const handleridx)
{
#if __CUDA_ARCH__ >= 900
cudaTriggerProgrammaticLaunchCompletion();
#endif
__shared__ int4* userptr[RANKS];
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
#if __CUDA_ARCH__ >= 900
cudaGridDependencySynchronize();
#endif
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -153,9 +153,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
#if __CUDA_ARCH__ >= 900
cudaGridDependencySynchronize();
#endif
flagptr[physgpu] = reduce_id;
userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu + handleridx]);
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
Expand Down Expand Up @@ -239,6 +236,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
#if __CUDA_ARCH__ >= 900
cudaTriggerProgrammaticLaunchCompletion();
#endif
} // fp16 inplace reduce kernel (Ampere)

#if __CUDA_ARCH__ >= 900
Expand Down Expand Up @@ -365,7 +365,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
*reduceidptr = reduce_id;
} // fp16 inplace reduce kernel (Hopper) MC

#else
#else // __CUDA_ARCH__ >= 900
template <typename DType, int RANKS, bool DISABLE_FP32_ACC>
__global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(int const op, int const flagoffset,
int const firstrank, int const myrank, int const gpustep, size_t const lineoffset, int const numlines,
Expand All @@ -375,7 +375,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
asm volatile("brkpt;\n");
}

#endif
#endif // __CUDA_ARCH__ >= 900

#define callranks(x) \
if (ar_nvsize == x) \
Expand Down Expand Up @@ -568,13 +568,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr int SF_VEC_SIZE = 16;
using PackedVec = PackedVec<DType>;
cudaTriggerProgrammaticLaunchCompletion();
float sf = *scale;
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);

int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
cudaGridDependencySynchronize();
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -585,7 +585,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
reduce_id = next_flag(reduce_id);
Expand Down Expand Up @@ -670,6 +669,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand All @@ -684,13 +684,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
constexpr int SF_VEC_SIZE = 16;
using PackedVec = PackedVec<DType>;
cudaTriggerProgrammaticLaunchCompletion();
float sf = *scale;
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);

int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
cudaGridDependencySynchronize();
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -701,7 +701,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
}
Expand Down Expand Up @@ -772,6 +771,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
cudaTriggerProgrammaticLaunchCompletion();
#endif
}

Expand All @@ -784,11 +784,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* mc_ptr_out,
size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
{
cudaTriggerProgrammaticLaunchCompletion();
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
cudaGridDependencySynchronize();
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -799,7 +799,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
reduce_id = next_flag(reduce_id);
Expand Down Expand Up @@ -874,6 +873,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
cudaTriggerProgrammaticLaunchCompletion();
} // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused

template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
Expand All @@ -883,11 +883,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
{
cudaTriggerProgrammaticLaunchCompletion();
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
cudaGridDependencySynchronize();
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -898,7 +898,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
}
Expand Down Expand Up @@ -962,6 +961,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
cudaTriggerProgrammaticLaunchCompletion();
} // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused oneshot

template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
Expand All @@ -971,13 +971,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
float const eps, int const RANKS, float2* mc_ptr_out, size_t const out_lineoffset, float const* scale,
uint4* residual_in, uint4* residual_out, int res_offset)
{
cudaTriggerProgrammaticLaunchCompletion();
float const sf = 1.f / (*scale);
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);

int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
cudaGridDependencySynchronize();
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -988,7 +988,6 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
reduce_id = next_flag(reduce_id);
Expand Down Expand Up @@ -1066,6 +1065,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
cudaTriggerProgrammaticLaunchCompletion();
} // quant kernel fp16->fp8 twoshot

template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
Expand All @@ -1075,13 +1075,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
float const eps, int const RANKS, uint2* mc_ptr_out, size_t const out_lineoffset, float const* scale,
uint4* residual_in, uint4* residual_out, int res_offset)
{
cudaTriggerProgrammaticLaunchCompletion();
float const sf = 1.f / (*scale);
__shared__ float s_variance;
int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);

int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
cudaGridDependencySynchronize();
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -1092,7 +1092,6 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
}
Expand Down Expand Up @@ -1160,6 +1159,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
cudaTriggerProgrammaticLaunchCompletion();
} // quant kernel fp16->fp8 oneshot

template <typename DType, int UNROLL_NLINES>
Expand All @@ -1168,9 +1168,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
int const myrank, int const gpustep, size_t const lineoffset, int const numlines, void** commbuff,
int const handleridx, float4* mc_ptr, int const RANKS, uint4* residual_in, int res_offset)
{
cudaTriggerProgrammaticLaunchCompletion();
int *flagptr, physgpu, targetgpu, *myptr;
int *reduceidptr, reduce_id;
cudaGridDependencySynchronize();
if (threadIdx.x < RANKS)
{
physgpu = myrank * gpustep + firstrank;
Expand All @@ -1181,7 +1181,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
reduce_id = next_flag(*reduceidptr);
flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
myptr += blockflagoffset;
cudaGridDependencySynchronize();
flagptr[physgpu] = reduce_id;
multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
reduce_id = next_flag(reduce_id);
Expand Down Expand Up @@ -1217,9 +1216,10 @@ __global__ void __launch_bounds__(MAX_THREADS)
}
if (threadIdx.x == 0 && blockIdx.x == 0)
*reduceidptr = reduce_id;
cudaTriggerProgrammaticLaunchCompletion();
} // residual allgather kernel

#else
#else // __CUDA_ARCH__ >= 900
template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
__global__ void __launch_bounds__(MAX_THREADS)
userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
Expand Down Expand Up @@ -1274,7 +1274,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
asm volatile("brkpt;\n");
}

#endif
#endif // __CUDA_ARCH__ >= 900

#define callranksMC_RMSNORM_QUANT(x) \
if (nlines == x) \
Expand Down
Loading