Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -377,40 +377,63 @@ inline __device__ T add(T a, T b)

#define FINAL_MASK 0xffffffff
#define WARP_SIZE 32
#define LOG2_WARP_SIZE 5
#define LANE_ID_MASK 0x1f

template <typename T>
__inline__ __device__ T warpReduceSum(T val)
__inline__ __device__ T warpReduceSumPartial(T val)
{
// Get the actual number of active threads in this warp
int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1)));
unsigned int mask = (1U << active_warp_size) - 1;
int lane_id = threadIdx.x & LANE_ID_MASK;
// This function should only be called on the last warp
int warp_size = blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1));
unsigned int active_mask = (1U << warp_size) - 1;

#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
{
if (offset < active_warp_size)
{
val = add<T>(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE));
}
int target_lane = lane_id ^ offset;
auto tmp = __shfl_xor_sync(active_mask, val, offset, WARP_SIZE);
val = add<T>(val, target_lane < warp_size ? tmp : 0);
}
return val;
}

template <typename T>
inline __device__ T warpReduceSumFull(T val)
{
#pragma unroll
for (int offset = 16; offset > 0; offset >>= 1)
{
val = add<T>(val, __shfl_xor_sync(FINAL_MASK, val, offset, WARP_SIZE));
}
return val;
}

inline __device__ float block_reduce_sum(float val)
{
__shared__ float smem[WARP_SIZE];
int lane_id = threadIdx.x % WARP_SIZE;
int warp_id = threadIdx.x / WARP_SIZE;
int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps
int lane_id = threadIdx.x & LANE_ID_MASK;
int warp_id = threadIdx.x >> LOG2_WARP_SIZE;
bool has_partial_warp = (blockDim.x % WARP_SIZE) != 0;
int warp_num = (blockDim.x + WARP_SIZE - 1) >> LOG2_WARP_SIZE; // Ceiling division to include partial warps

val = warpReduceSum(val);
val = (has_partial_warp && (warp_id == warp_num - 1)) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
if (lane_id == 0)
{
smem[warp_id] = val;
}
__syncthreads();
val = lane_id < warp_num ? smem[lane_id] : 0.f;
val = warpReduceSum(val);
if (warp_id == 0)
{
val = lane_id < warp_num ? smem[lane_id] : 0.f;
val = (has_partial_warp && warp_num == 1) ? warpReduceSumPartial(val) : warpReduceSumFull(val);
if (lane_id == 0)
{
smem[0] = val;
}
}
__syncthreads();
val = smem[0];

return val;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/unittest/_torch/multi_gpu/test_mnnvl_allreduce.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def func(input, residual, norm_weight, eps, enable_fusion):
)


@pytest.mark.skip(reason="https://nvbugs/5597647")
#@pytest.mark.skip(reason="https://nvbugs/5597647")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please remove this line completely instead of commenting it out.

@pytest.mark.skipif(torch.cuda.device_count() < 2,
reason="needs 2 GPUs to run this test")
@pytest.mark.parametrize(
Expand Down