
Commit 4bdf684 (parent: 095b7a3)

Try to fix the accuracy issue with warp reduce.

Signed-off-by: Shiyu Li <[email protected]>

File tree: 2 files changed, +38 −15 lines

cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu

Lines changed: 37 additions & 14 deletions
@@ -377,40 +377,63 @@ inline __device__ T add(T a, T b)
 
 #define FINAL_MASK 0xffffffff
 #define WARP_SIZE 32
+#define LOG2_WARP_SIZE 5
+#define LANE_ID_MASK 0x1f
 
 template <typename T>
-__inline__ __device__ T warpReduceSum(T val)
+__inline__ __device__ T warpReduceSumPartial(T val)
 {
-    // Get the actual number of active threads in this warp
-    int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)));
-    unsigned int mask = (1U << active_warp_size) - 1;
+    int lane_id = threadIdx.x & LANE_ID_MASK;
+    // This function should only be called on the last warp
+    int warp_size = blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1));
+    unsigned int active_mask = (1U << warp_size) - 1;
 
 #pragma unroll
     for (int offset = 16; offset > 0; offset >>= 1)
     {
-        if (offset < active_warp_size)
-        {
-            val = add<T>(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE));
-        }
+        int target_lane = lane_id ^ offset;
+        auto tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
+        val = add<T>(val, target_lane < warp_size ? tmp : 0);
+    }
+    return val;
+}
+
+template <typename T>
+inline __device__ T warpReduceSumFull(T val)
+{
+#pragma unroll
+    for (int offset = 16; offset > 0; offset >>= 1)
+    {
+        val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE));
     }
     return val;
 }
 
 inline __device__ float block_reduce_sum(float val)
 {
     __shared__ float smem[WARP_SIZE];
-    int lane_id = threadIdx.x % WARP_SIZE;
-    int warp_id = threadIdx.x / WARP_SIZE;
-    int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps
+    int lane_id = threadIdx.x & LANE_ID_MASK;
+    int warp_id = threadIdx.x >> LOG2_WARP_SIZE;
+    bool has_partial_warp = (blockDim.x % WARP_SIZE) != 0;
+    int warp_num = (blockDim.x + WARP_SIZE - 1) >> LOG2_WARP_SIZE; // Ceiling division to include partial warps
 
-    val = warpReduceSum(val);
+    val = (has_partial_warp && (warp_id == warp_num - 1)) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
     if (lane_id == 0)
     {
         smem[warp_id] = val;
     }
     __syncthreads();
-    val = lane_id < warp_num ? smem[lane_id] : 0.f;
-    val = warpReduceSum(val);
+    if (warp_id == 0)
+    {
+        val = lane_id < warp_num ? smem[lane_id] : 0.f;
+        val = (has_partial_warp && warp_num == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
+        if (lane_id == 0)
+        {
+            smem[0] = val;
+        }
+    }
+    __syncthreads();
+    val = smem[0];
 
     return val;
 }
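
For readers outside the TensorRT-LLM tree: the change above splits the warp reduction into two paths. warpReduceSumFull uses the full 32-lane mask, while warpReduceSumPartial, used only for the block's trailing warp, builds a mask that names just the lanes that exist in the block and discards values shuffled in from lanes beyond it. The sketch below is a self-contained toy program, not part of the commit, that exercises the same pattern with a single block whose size is deliberately not a multiple of 32; the kernel name blockSumKernel, the float-only types, and the host harness are illustrative assumptions.

// Standalone illustration of the partial/full warp-reduce split; not from the commit.
#include <cstdio>
#include <cuda_runtime.h>

#define WARP_SIZE 32
#define FINAL_MASK 0xffffffff
#define LANE_ID_MASK 0x1f
#define LOG2_WARP_SIZE 5

// Partial-warp sum: only lanes that exist in the block are named in the mask,
// and values shuffled in from non-existent lanes are discarded.
__inline__ __device__ float warpReduceSumPartial(float val)
{
    int lane_id = threadIdx.x & LANE_ID_MASK;
    int warp_size = blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)); // live lanes in this warp
    unsigned int active_mask = (1U << warp_size) - 1;
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        int target_lane = lane_id ^ offset;
        float tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
        val += (target_lane < warp_size) ? tmp : 0.f;
    }
    return val;
}

// Full-warp sum over all 32 lanes with the full mask.
__inline__ __device__ float warpReduceSumFull(float val)
{
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
        val += __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE);
    return val;
}

// Single-block sum of blockDim.x floats, mirroring the block_reduce_sum structure above.
__global__ void blockSumKernel(float const* in, float* out)
{
    __shared__ float smem[WARP_SIZE];
    float val = in[threadIdx.x];

    int lane_id = threadIdx.x & LANE_ID_MASK;
    int warp_id = threadIdx.x >> LOG2_WARP_SIZE;
    bool has_partial_warp = (blockDim.x % WARP_SIZE) != 0;
    int warp_num = (blockDim.x + WARP_SIZE - 1) >> LOG2_WARP_SIZE;

    // Only the last warp can be partial; every other warp takes the full-mask path.
    val = (has_partial_warp && warp_id == warp_num - 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
    if (lane_id == 0) smem[warp_id] = val;
    __syncthreads();

    if (warp_id == 0)
    {
        val = (lane_id < warp_num) ? smem[lane_id] : 0.f;
        // The second pass is only partial when the whole block fits in a single warp.
        val = (has_partial_warp && warp_num == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
        if (lane_id == 0) smem[0] = val;
    }
    __syncthreads();

    if (threadIdx.x == 0) *out = smem[0];
}

int main()
{
    int const n = 100; // deliberately not a multiple of 32, so the last warp is partial
    float h_in[n], h_out = 0.f, expected = 0.f;
    for (int i = 0; i < n; ++i) { h_in[i] = 0.25f * i; expected += h_in[i]; }

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    blockSumKernel<<<1, n>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);

    printf("device sum = %.2f, expected = %.2f\n", h_out, expected);
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}

Built with nvcc and run on one GPU, the device sum should match the host sum for this input. With the previous single-path warpReduceSum, the trailing warp could fold in undefined values shuffled from lanes that are not part of the block, which appears to be the accuracy issue the commit message refers to.
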

tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def func(input, residual, norm_weight, eps, enable_fusion):
     )
 
 
-@pytest.mark.skip(reason="https://nvbugs/5597647")
+#@pytest.mark.skip(reason="https://nvbugs/5597647")
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="needs 2 GPUs to run this test")
 @pytest.mark.parametrize(
