Commit 62dc3b7

dc3671 and fredricz-20070104 authored and committed
[https://nvbugs/5545522][fix] move PREEXIT in UB kernels to fix accuracy issue (NVIDIA#8318)
Signed-off-by: Zhenhuan Chen <[email protected]>
Signed-off-by: Mike Iovine <[email protected]>
Signed-off-by: FredricZ-2007 <[email protected]>
1 parent 413d9b8 commit 62dc3b7
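
For context (not part of the commit): these kernels use CUDA Programmatic Dependent Launch (PDL). cudaTriggerProgrammaticLaunchCompletion() (PTX griddepcontrol.launch_dependents) lets the next grid in the stream start launching once every block of the current grid has issued it, while cudaGridDependencySynchronize() (PTX griddepcontrol.wait) makes the current grid wait on its upstream dependency before reading data the prior kernel produced. The diff below repositions both primitives; in the userbuffers kernels the trigger now fires only after the reduction results are written, rather than at kernel entry. A minimal sketch of that ordering follows; the kernel and buffer names are hypothetical and not taken from the repository.

// pdl_ordering_sketch.cu -- illustrative only; names are hypothetical, not from this repo.
// A kernel that both consumes the previous grid's output and produces data for the
// next grid, using the placement this fix converges on: wait before the first read
// of upstream data, trigger only after this kernel's own output is written.
#include <cuda_runtime.h>

__global__ void consume_then_produce(float const* upstream_out, float* my_out, int n)
{
#if __CUDA_ARCH__ >= 900
    // Block until the prior grid in the stream has finished and its writes
    // (upstream_out) are visible to this grid.
    cudaGridDependencySynchronize();
#endif
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        my_out[i] = 2.f * upstream_out[i]; // safe: dependency satisfied above

#if __CUDA_ARCH__ >= 900
    // Signal that dependent grids may begin launching; every block must issue
    // this (or exit) before the dependents are released.
    cudaTriggerProgrammaticLaunchCompletion();
#endif
}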

2 files changed: +32 -31 lines changed

cpp/tensorrt_llm/kernels/dsv3MinLatencyKernels/dsv3FusedAGemm.cu

Lines changed: 2 additions & 1 deletion
@@ -601,6 +601,8 @@ __global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
         }
     }
     __syncthreads();
+    asm volatile("griddepcontrol.wait;");
+    asm volatile("griddepcontrol.launch_dependents;");
 
     if (warp_idx < 2)
     {
@@ -622,7 +624,6 @@ __global__ __launch_bounds__(256, 1) void fused_a_gemm_kernel(
         mma_computer.issue_mainloop();
         mma_computer.epi();
     }
-    asm volatile("griddepcontrol.launch_dependents;");
 #endif
 }
 
cpp/tensorrt_llm/kernels/userbuffers/userbuffers.cu

Lines changed: 30 additions & 30 deletions
@@ -56,12 +56,12 @@ __global__ void __launch_bounds__(MAX_THREADS)
     userbuffers_fp16_sum_inplace_gpu_rw(int const op, int const flagoffset, int const firstrank, int const myrank,
         int const gpustep, size_t const lineoffset, int const numlines, void** commbuff, int const handleridx)
 {
-#if __CUDA_ARCH__ >= 900
-    cudaTriggerProgrammaticLaunchCompletion();
-#endif
     __shared__ int4* userptr[RANKS];
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+#if __CUDA_ARCH__ >= 900
+    cudaGridDependencySynchronize();
+#endif
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -72,9 +72,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-#if __CUDA_ARCH__ >= 900
-        cudaGridDependencySynchronize();
-#endif
         flagptr[physgpu] = reduce_id;
         userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu + handleridx]);
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
@@ -130,19 +127,22 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+#if __CUDA_ARCH__ >= 900
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 } // fp16 inplace reduce kernel (Hopper)
 
 template <typename DType, int RANKS>
 __global__ void __launch_bounds__(MAX_THREADS)
     userbuffers_fp16_sum_inplace_gpu_rr(int const op, int const flagoffset, int const firstrank, int const myrank,
         int const gpustep, size_t const lineoffset, int const numlines, void** commbuff, int const handleridx)
 {
-#if __CUDA_ARCH__ >= 900
-    cudaTriggerProgrammaticLaunchCompletion();
-#endif
     __shared__ int4* userptr[RANKS];
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+#if __CUDA_ARCH__ >= 900
+    cudaGridDependencySynchronize();
+#endif
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -153,9 +153,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-#if __CUDA_ARCH__ >= 900
-        cudaGridDependencySynchronize();
-#endif
         flagptr[physgpu] = reduce_id;
         userptr[threadIdx.x] = reinterpret_cast<int4*>(commbuff[targetgpu + handleridx]);
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
@@ -239,6 +236,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+#if __CUDA_ARCH__ >= 900
+    cudaTriggerProgrammaticLaunchCompletion();
+#endif
 } // fp16 inplace reduce kernel (Ampere)
 
 #if __CUDA_ARCH__ >= 900
@@ -365,7 +365,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
         *reduceidptr = reduce_id;
 } // fp16 inplace reduce kernel (Hopper) MC
 
-#else
+#else // __CUDA_ARCH__ >= 900
 template <typename DType, int RANKS, bool DISABLE_FP32_ACC>
 __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_mc(int const op, int const flagoffset,
     int const firstrank, int const myrank, int const gpustep, size_t const lineoffset, int const numlines,
@@ -375,7 +375,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     asm volatile("brkpt;\n");
 }
 
-#endif
+#endif // __CUDA_ARCH__ >= 900
 
 #define callranks(x) \
     if (ar_nvsize == x) \
@@ -568,13 +568,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
     constexpr int SF_VEC_SIZE = 16;
     using PackedVec = PackedVec<DType>;
-    cudaTriggerProgrammaticLaunchCompletion();
     float sf = *scale;
     __shared__ float s_variance;
     int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
 
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+    cudaGridDependencySynchronize();
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -585,7 +585,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-        cudaGridDependencySynchronize();
         flagptr[physgpu] = reduce_id;
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
         reduce_id = next_flag(reduce_id);
@@ -670,6 +669,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
 
@@ -684,13 +684,13 @@ __global__ void __launch_bounds__(MAX_THREADS)
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000)
     constexpr int SF_VEC_SIZE = 16;
     using PackedVec = PackedVec<DType>;
-    cudaTriggerProgrammaticLaunchCompletion();
    float sf = *scale;
     __shared__ float s_variance;
     int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
 
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+    cudaGridDependencySynchronize();
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -701,7 +701,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-        cudaGridDependencySynchronize();
         flagptr[physgpu] = reduce_id;
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
     }
@@ -772,6 +771,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+    cudaTriggerProgrammaticLaunchCompletion();
 #endif
 }
 
@@ -784,11 +784,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS, float4* mc_ptr_out,
     size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
 {
-    cudaTriggerProgrammaticLaunchCompletion();
     __shared__ float s_variance;
     int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+    cudaGridDependencySynchronize();
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -799,7 +799,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-        cudaGridDependencySynchronize();
         flagptr[physgpu] = reduce_id;
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
         reduce_id = next_flag(reduce_id);
@@ -874,6 +873,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+    cudaTriggerProgrammaticLaunchCompletion();
 } // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused
 
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
@@ -883,11 +883,11 @@ __global__ void __launch_bounds__(MAX_THREADS)
     int const handleridx, float4* mc_ptr, DType const* beta, DType const* gamma, float const eps, int const RANKS,
     uint4* uc_ptr_out, size_t const out_lineoffset, uint4* residual_in, uint4* residual_out, int res_offset)
 {
-    cudaTriggerProgrammaticLaunchCompletion();
     __shared__ float s_variance;
     int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+    cudaGridDependencySynchronize();
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -898,7 +898,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-        cudaGridDependencySynchronize();
         flagptr[physgpu] = reduce_id;
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
     }
@@ -962,6 +961,7 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+    cudaTriggerProgrammaticLaunchCompletion();
 } // fp16 inplace reduce kernel (Hopper) MC with rmsNorm fused oneshot
 
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
@@ -971,13 +971,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     float const eps, int const RANKS, float2* mc_ptr_out, size_t const out_lineoffset, float const* scale,
     uint4* residual_in, uint4* residual_out, int res_offset)
 {
-    cudaTriggerProgrammaticLaunchCompletion();
     float const sf = 1.f / (*scale);
     __shared__ float s_variance;
     int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
 
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+    cudaGridDependencySynchronize();
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -988,7 +988,6 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-        cudaGridDependencySynchronize();
         flagptr[physgpu] = reduce_id;
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
         reduce_id = next_flag(reduce_id);
@@ -1066,6 +1065,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+    cudaTriggerProgrammaticLaunchCompletion();
 } // quant kernel fp16->fp8 twoshot
 
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
@@ -1075,13 +1075,13 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     float const eps, int const RANKS, uint2* mc_ptr_out, size_t const out_lineoffset, float const* scale,
     uint4* residual_in, uint4* residual_out, int res_offset)
 {
-    cudaTriggerProgrammaticLaunchCompletion();
     float const sf = 1.f / (*scale);
     __shared__ float s_variance;
     int hidden_dim = blockDim.x * UNROLL_NLINES * sizeof(int4) / sizeof(DType);
 
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+    cudaGridDependencySynchronize();
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -1092,7 +1092,6 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-        cudaGridDependencySynchronize();
         flagptr[physgpu] = reduce_id;
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
     }
@@ -1160,6 +1159,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+    cudaTriggerProgrammaticLaunchCompletion();
 } // quant kernel fp16->fp8 oneshot
 
 template <typename DType, int UNROLL_NLINES>
@@ -1168,9 +1168,9 @@ __global__ void __launch_bounds__(MAX_THREADS)
     int const myrank, int const gpustep, size_t const lineoffset, int const numlines, void** commbuff,
     int const handleridx, float4* mc_ptr, int const RANKS, uint4* residual_in, int res_offset)
 {
-    cudaTriggerProgrammaticLaunchCompletion();
     int *flagptr, physgpu, targetgpu, *myptr;
     int *reduceidptr, reduce_id;
+    cudaGridDependencySynchronize();
     if (threadIdx.x < RANKS)
     {
         physgpu = myrank * gpustep + firstrank;
@@ -1181,7 +1181,6 @@ __global__ void __launch_bounds__(MAX_THREADS)
         reduce_id = next_flag(*reduceidptr);
         flagptr = (reinterpret_cast<int*>(commbuff[targetgpu])) + flagoffset + blockflagoffset;
         myptr += blockflagoffset;
-        cudaGridDependencySynchronize();
         flagptr[physgpu] = reduce_id;
         multi_gpu_block_barrier(reduce_id, (int volatile*) &myptr[targetgpu]);
         reduce_id = next_flag(reduce_id);
@@ -1217,9 +1216,10 @@ __global__ void __launch_bounds__(MAX_THREADS)
     }
     if (threadIdx.x == 0 && blockIdx.x == 0)
         *reduceidptr = reduce_id;
+    cudaTriggerProgrammaticLaunchCompletion();
 } // residual allgather kernel
 
-#else
+#else // __CUDA_ARCH__ >= 900
 template <typename DType, int UNROLL_NLINES, bool DISABLE_FP32_ACC>
 __global__ void __launch_bounds__(MAX_THREADS)
     userbuffers_fp16_sum_gpu_mc_rmsnorm(int const op, int const flagoffset, int const firstrank, int const myrank,
@@ -1274,7 +1274,7 @@ __global__ void __launch_bounds__(MAX_THREADS) userbuffers_fp16_sum_inplace_gpu_
     asm volatile("brkpt;\n");
 }
 
-#endif
+#endif // __CUDA_ARCH__ >= 900
 
 #define callranksMC_RMSNORM_QUANT(x) \
     if (nlines == x) \
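
One host-side note, outside the scope of this diff: the device-side wait and trigger only take effect when the dependent kernel is launched with the programmatic stream serialization attribute; otherwise the launch falls back to ordinary stream ordering. A minimal sketch of such a launch through the CUDA runtime, with placeholder kernel and buffer names (not from this repository):

// pdl_launch_sketch.cu -- illustrative only; kernel/argument names are placeholders.
// Shows how a dependent kernel is opted into Programmatic Dependent Launch with the
// CUDA runtime (CUDA 11.8+, compute capability 9.0+).
#include <cuda_runtime.h>

__global__ void dependent_kernel(float const* upstream_out, float* out, int n)
{
#if __CUDA_ARCH__ >= 900
    cudaGridDependencySynchronize(); // block until the upstream grid's writes are visible
#endif
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        out[i] = upstream_out[i] + 1.f;
}

void launch_with_pdl(float const* upstream_out, float* out, int n, cudaStream_t stream)
{
    cudaLaunchAttribute attr{};
    attr.id = cudaLaunchAttributeProgrammaticStreamSerialization;
    attr.val.programmaticStreamSerializationAllowed = 1; // enable PDL for this launch

    cudaLaunchConfig_t cfg{};
    cfg.gridDim = dim3((n + 255) / 256);
    cfg.blockDim = dim3(256);
    cfg.dynamicSmemBytes = 0;
    cfg.stream = stream;
    cfg.attrs = &attr;
    cfg.numAttrs = 1;

    // The dependent kernel may begin launching as soon as the preceding kernel in
    // `stream` triggers programmatic launch completion (or exits).
    cudaLaunchKernelEx(&cfg, dependent_kernel, upstream_out, out, n);
}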
