
Commit a81a658

Try to fix the accuracy issue with warp reduce.
Signed-off-by: Shiyu Li <[email protected]>
1 parent 095b7a3 commit a81a658

File tree

2 files changed (+18 −10 lines)


cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu

Lines changed: 17 additions & 9 deletions
@@ -381,17 +381,16 @@ inline __device__ T add(T a, T b)
 template <typename T>
 __inline__ __device__ T warpReduceSum(T val)
 {
-    // Get the actual number of active threads in this warp
-    int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)));
-    unsigned int mask = (1U << active_warp_size) - 1;
+    int lane_id = threadIdx.x & 0x1f;
+    int warp_size = blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1));
+    unsigned int active_mask = (1U << warp_size) - 1;
 
 #pragma unroll
     for (int offset = 16; offset > 0; offset >>= 1)
     {
-        if (offset < active_warp_size)
-        {
-            val = add<T>(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE));
-        }
+        int target_lane = lane_id ^ offset;
+        auto tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
+        val = add<T>(val, target_lane < warp_size ? tmp : 0);
     }
     return val;
 }
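
The rewritten loop above shuffles on every step and simply treats a partner lane that falls past the end of the warp as contributing zero, which behaves like padding a partial warp with zeros. The sketch below is a minimal, self-contained illustration of that masked partial-warp reduction; the names (partialWarpReduceSum, reduceKernel, the tiny host harness) and the extra min()/full-warp-mask guards are assumptions added here for a standalone example, not code taken from the commit.

// Minimal sketch of a masked partial-warp shuffle reduction.
// Names and the defensive guards are illustrative assumptions, not commit code.
#include <cstdio>
#include <cuda_runtime.h>

constexpr int WARP_SIZE = 32;

__device__ float partialWarpReduceSum(float val)
{
    int lane_id = threadIdx.x & (WARP_SIZE - 1);
    // Number of threads that actually exist in this (possibly partial) warp.
    int warp_base = static_cast<int>(threadIdx.x) & ~(WARP_SIZE - 1);
    int warp_size = min(WARP_SIZE, static_cast<int>(blockDim.x) - warp_base);
    // Full-warp mask is special-cased so the shift stays well defined.
    unsigned int active_mask = (warp_size == WARP_SIZE) ? 0xffffffffu : ((1u << warp_size) - 1u);
#pragma unroll
    for (int offset = 16; offset > 0; offset >>= 1)
    {
        int target_lane = lane_id ^ offset;
        float tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
        // A partner lane past the end of the warp contributes zero, which is
        // equivalent to padding the partial warp with zeros.
        val += (target_lane < warp_size) ? tmp : 0.f;
    }
    return val;
}

__global__ void reduceKernel(float const* in, float* out)
{
    float v = partialWarpReduceSum(in[threadIdx.x]);
    if ((threadIdx.x & (WARP_SIZE - 1)) == 0)
    {
        out[threadIdx.x / WARP_SIZE] = v;
    }
}

int main()
{
    int const n = 48; // one full warp plus a partial warp of 16 threads
    float h_in[48], h_out[2];
    for (int i = 0; i < n; ++i) h_in[i] = 1.f;

    float *d_in, *d_out;
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, 2 * sizeof(float));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);

    reduceKernel<<<1, n>>>(d_in, d_out);
    cudaMemcpy(h_out, d_out, 2 * sizeof(float), cudaMemcpyDeviceToHost);

    // Expected: 32 for the full warp and 16 for the partial warp.
    printf("warp 0 sum = %.0f, warp 1 sum = %.0f\n", h_out[0], h_out[1]);

    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}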
@@ -409,8 +408,17 @@ inline __device__ float block_reduce_sum(float val)
         smem[warp_id] = val;
     }
     __syncthreads();
-    val = lane_id < warp_num ? smem[lane_id] : 0.f;
-    val = warpReduceSum(val);
+    if (warp_id == 0)
+    {
+        val = lane_id < warp_num ? smem[lane_id] : 0.f;
+        val = warpReduceSum(val);
+        if (lane_id == 0)
+        {
+            val = smem[0];
+        }
+    }
+    __syncthreads();
+    val = smem[0];
 
     return val;
 }
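
The hunk above restricts the second reduction stage of block_reduce_sum to warp 0 and then has every thread read the result from shared memory after a second __syncthreads(). The sketch below shows the general two-stage block-reduction pattern this corresponds to; the helper name, the 32-slot shared array, and the explicit write-back of the warp-0 result into smem[0] are assumptions for illustration (building on the partialWarpReduceSum sketch above), not lines quoted from the diff.

// Compact sketch of a two-stage block reduction: each warp reduces its own
// values, warp 0 combines the per-warp partials, and the block-wide sum is
// broadcast to all threads through shared memory.
__device__ float blockReduceSumSketch(float val)
{
    __shared__ float smem[32];                     // one slot per warp (up to 1024 threads)
    int lane_id = threadIdx.x & (WARP_SIZE - 1);
    int warp_id = threadIdx.x / WARP_SIZE;
    int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE;

    val = partialWarpReduceSum(val);               // stage 1: reduce inside each warp
    if (lane_id == 0)
    {
        smem[warp_id] = val;                       // publish this warp's partial sum
    }
    __syncthreads();

    if (warp_id == 0)
    {
        // Stage 2: only warp 0 combines the per-warp partials.
        val = lane_id < warp_num ? smem[lane_id] : 0.f;
        val = partialWarpReduceSum(val);
        if (lane_id == 0)
        {
            smem[0] = val;                         // store the block-wide sum
        }
    }
    __syncthreads();
    return smem[0];                                // every thread reads the broadcast result
}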

tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py

Lines changed: 1 addition & 1 deletion
@@ -161,7 +161,7 @@ def func(input, residual, norm_weight, eps, enable_fusion):
     )
 
 
-@pytest.mark.skip(reason="https://nvbugs/5597647")
+#@pytest.mark.skip(reason="https://nvbugs/5597647")
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
                     reason="needs 2 GPUs to run this test")
 @pytest.mark.parametrize(
