
Commit 7d16f3a

[https://nvbugs/5788127][fix] Use uint64_t as the dtype of lamport_buffer_size to avoid overflow (#10499)
Signed-off-by: Yilin Zhang <18275976+yilin-void@users.noreply.github.com>
1 parent bdaee87 commit 7d16f3a

File tree

5 files changed: 55 additions & 38 deletions


cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu

Lines changed: 5 additions & 5 deletions
@@ -70,9 +70,9 @@ struct LamportComm
     {
         counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
         flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
-        clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
+        clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
         flag_value = *flag_ptr;
-        int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
+        auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
         clear_size = *clear_ptr;
         int data_offset = flag_value % 3;
         int clear_offset = (flag_value + 2) % 3;

@@ -88,7 +88,7 @@ struct LamportComm
         }
     }

-    __device__ __forceinline__ void update(int new_clear_size)
+    __device__ __forceinline__ void update(int64_t new_clear_size)
     {
         if (blockIdx.x == 0 && threadIdx.x == 0)
         {

@@ -103,10 +103,10 @@ struct LamportComm

     int* counter_ptr;
     int* flag_ptr;
-    int* clear_ptr;
+    int64_t* clear_ptr;
     uint8_t* data_bufs[NRanks];
     uint8_t* clear_buf;
-    int clear_size;
+    int64_t clear_size;
     int flag_value;
 };
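
The same re-indexing is applied in moeAllReduceFusionKernels.cu below. As a reading aid, here is a minimal, hypothetical sketch (not code from this commit; the helper name read_lamport_metadata is made up for illustration) of how LamportComm resolves its metadata from the workspace pointer array after this change, using the names workspace and NRanks from the kernel code:

#include <cstdint>

// Workspace layout as reconstructed from the diff above:
//   workspace[0 .. NRanks * 3 - 1] : per-rank pointers of the three IPC buffer sets
//   workspace[NRanks * 3]          : 3 x int   {atomic read counter, non-lamport flag, lamport flag}
//   workspace[NRanks * 3 + 1]      : 2 x int64 {clear size, triple-buffer offset (lamport_comm_size)}
template <int NRanks>
__device__ __forceinline__ void read_lamport_metadata(
    void** workspace, int& flag_value, int64_t& clear_size, int64_t& comm_size)
{
    int* flag_ptr = reinterpret_cast<int*>(workspace[NRanks * 3]);
    int64_t* layout_ptr = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1]);
    flag_value = flag_ptr[2];    // lamport flag
    clear_size = layout_ptr[0];  // bytes the next lamport kernel has to clear
    comm_size = layout_ptr[1];   // byte offset between the three lamport buffer slots (lamport_comm_size)
}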

cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu

Lines changed: 27 additions & 22 deletions
@@ -21,18 +21,18 @@ TRTLLM_NAMESPACE_BEGIN
 namespace kernels::ar_fusion
 {

-__global__ void lamport_initialize_kernel(float* ptr, int size)
+__global__ void lamport_initialize_kernel(float* ptr, size_t size)
 {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
     if (idx >= size)
         return;
     ptr[idx] = -0.f;
 }

-void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
+void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream)
 {
-    int grid_size = (bytes + 127) / 128;
-    lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
+    int grid_size = static_cast<int>((bytes + 1023) / 1024);
+    lamport_initialize_kernel<<<grid_size, 1024, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
 }

 Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,

@@ -45,10 +45,11 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
     int device_id;
     TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
     m_buffer_mgr = std::make_shared<tensorrt_llm::runtime::BufferManager>(m_cuda_stream);
-    int buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
-    int flag_size = tp_size * kBarrierFlagCount * sizeof(int);
-    int lamport_comm_size = tp_size * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
-    int lamport_buffer_size = 3 * lamport_comm_size;
+    size_t buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
+    size_t flag_size = tp_size * kBarrierFlagCount * sizeof(int);
+    size_t lamport_comm_size
+        = static_cast<size_t>(tp_size) * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
+    size_t lamport_buffer_size = 3 * lamport_comm_size;
     for (auto size : {buffer_size, flag_size, lamport_buffer_size})
     {
         m_ipc_mem_handles.emplace_back(size, *m_buffer_mgr, m_world_config, p2p_supported);

@@ -61,20 +62,20 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
             workspace.push_back(ipc_mem_handle.getCommPtrs()[r]);
         }
     }
-    // atomic flag read counter
-    // kernel_flag_ptr[0] = 0;
-    // non-lamport flag
-    // kernel_flag_ptr[1] = 0;
-    // lamport flag
-    // kernel_flag_ptr[2] = 0;
-    // lamport triple buffer offset
-    // kernel_flag_ptr[3] = lamport_comm_size;
-    // lamport clear size
-    // kernel_flag_ptr[4] = 0;
-    TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 5 * sizeof(int)));
-    std::vector<int> h_data{0, 0, 0, lamport_comm_size, 0};
-    TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_data.data(), 5 * sizeof(int), cudaMemcpyHostToDevice));
+    // flag_buffer[0], atomic flag read counter
+    // flag_buffer[1], non-lamport flag
+    // flag_buffer[2], lamport flag
+    TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 3 * sizeof(int)));
+    std::vector<int> h_flag_data{0, 0, 0};
+    TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_flag_data.data(), 3 * sizeof(int), cudaMemcpyHostToDevice));
     workspace.push_back(m_flag_d_ptr);
+    // layout_buffer[0], clear size for next lamport kernel
+    // layout_buffer[1], triple buffer offset for lamport kernel
+    TLLM_CUDA_CHECK(cudaMalloc(&m_layout_d_ptr, 2 * sizeof(int64_t)));
+    std::vector<int64_t> h_layout_data{0, static_cast<int64_t>(lamport_comm_size)};
+    TLLM_CUDA_CHECK(cudaMemcpy(m_layout_d_ptr, h_layout_data.data(), 2 * sizeof(int64_t), cudaMemcpyHostToDevice));
+    workspace.push_back(m_layout_d_ptr);
+
     TLLM_CUDA_CHECK(cudaMalloc(&m_workspace, workspace.size() * sizeof(void*)));
     TLLM_CUDA_CHECK(
         cudaMemcpy(m_workspace, workspace.data(), workspace.size() * sizeof(void*), cudaMemcpyHostToDevice));

@@ -87,6 +88,10 @@ Workspace::~Workspace()
     {
         TLLM_CUDA_CHECK(cudaFree(m_flag_d_ptr));
     }
+    if (m_layout_d_ptr)
+    {
+        TLLM_CUDA_CHECK(cudaFree(m_layout_d_ptr));
+    }
     if (m_workspace)
     {
         TLLM_CUDA_CHECK(cudaFree(m_workspace));
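
The overflow this file fixes is easiest to see with concrete numbers. Below is a minimal standalone sketch; the shapes are hypothetical and not taken from this commit:

#include <cstdint>
#include <cstdio>
#include <limits>

int main()
{
    // Hypothetical shapes: tp_size = 8, std::max(kOneShotMaxToken, max_token_num) = 8192,
    // hidden_dim = 8192, sizeof(half) = 2 bytes.
    std::int64_t lamport_comm_size = 8LL * 8192 * 8192 * 2;    // 1'073'741'824 bytes, still fits in a 32-bit int
    std::int64_t lamport_buffer_size = 3 * lamport_comm_size;  // 3'221'225'472 bytes for the triple buffer
    std::printf("lamport_buffer_size = %lld, INT32_MAX = %d\n",
        static_cast<long long>(lamport_buffer_size),
        static_cast<int>(std::numeric_limits<std::int32_t>::max()));
    // 3'221'225'472 > 2'147'483'647, so the old `int lamport_buffer_size` overflowed, while the
    // size_t host computation and the int64_t layout buffer read by the kernels hold it comfortably.
    return 0;
}

The same concern is why lamport_initialize now takes size_t bytes and the initialization kernel indexes with size_t instead of int.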

cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h

Lines changed: 2 additions & 1 deletion
@@ -41,9 +41,10 @@ class Workspace
     void* m_workspace;
     std::shared_ptr<tensorrt_llm::runtime::CudaStream> m_cuda_stream;
     void* m_flag_d_ptr;
+    void* m_layout_d_ptr;
 };

-void lamport_initialize(void* ptr, int bytes, cudaStream_t stream);
+void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream);
 } // namespace kernels::ar_fusion

 TRTLLM_NAMESPACE_END

cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu

Lines changed: 5 additions & 5 deletions
@@ -31,9 +31,9 @@ struct LamportComm
     {
         counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
         flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
-        clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
+        clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
         flag_value = *flag_ptr;
-        int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
+        auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
         clear_size = *clear_ptr;
         int data_offset = flag_value % 3;
         int clear_offset = (flag_value + 2) % 3;

@@ -49,7 +49,7 @@ struct LamportComm
         }
     }

-    __device__ __forceinline__ void update(int new_clear_size)
+    __device__ __forceinline__ void update(int64_t new_clear_size)
     {
         if (blockIdx.x == 0 && threadIdx.x == 0)
         {

@@ -64,10 +64,10 @@ struct LamportComm

     int* counter_ptr;
     int* flag_ptr;
-    int* clear_ptr;
+    int64_t* clear_ptr;
     uint8_t* data_bufs[NRanks];
     uint8_t* clear_buf;
-    int clear_size;
+    int64_t clear_size;
     int flag_value;
 };

tensorrt_llm/plugin/plugin.py

Lines changed: 16 additions & 5 deletions
@@ -737,14 +737,25 @@ def allocate_allreduce_fusion_workspace(
         lamport_buffers.local_ptr,
         3 * lamport_buffers_size,
     )
-    flag_buffer = torch.tensor([0, 0, 0, lamport_buffers_size, 0],
-                               dtype=torch.int,
-                               device="cuda")
-    buffers = [ipc_buffers, ipc_barriers, lamport_buffers, flag_buffer]
+    # flag_buffer[0], atomic flag read counter
+    # flag_buffer[1], non-lamport flag
+    # flag_buffer[2], lamport flag
+    flag_buffer = torch.tensor([0, 0, 0], dtype=torch.int, device="cuda")
+    # layout_buffer[0], clear size for next lamport kernel
+    # layout_buffer[1], triple buffer offset for lamport kernel
+    layout_buffer = torch.tensor([0, lamport_buffers_size],
+                                 dtype=torch.int64,
+                                 device="cuda")
+
+    buffers = [
+        ipc_buffers, ipc_barriers, lamport_buffers, flag_buffer,
+        layout_buffer
+    ]

     return buffers, torch.tensor(
         ipc_buffers.serialize() + ipc_barriers.serialize() +
-        lamport_buffers.serialize() + [flag_buffer.data_ptr()],
+        lamport_buffers.serialize() + [flag_buffer.data_ptr()] +
+        [layout_buffer.data_ptr()],
         dtype=torch.int64,
         device="cuda")
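
A note on ordering: as in the C++ Workspace constructor above, the flag buffer pointer is appended to the serialized workspace before the layout buffer pointer, so (assuming the serialized IPC buffers account for the first NRanks * 3 entries, as they do on the C++ side) the kernels' workspace[NRanks * 3] and workspace[NRanks * 3 + 1] resolve to flag_buffer and layout_buffer respectively.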
