diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
index 25c662534d9..75bbddb5664 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu
@@ -70,9 +70,9 @@ struct LamportComm
     {
         counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
         flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
-        clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
+        clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
         flag_value = *flag_ptr;
-        int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
+        auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
         clear_size = *clear_ptr;
         int data_offset = flag_value % 3;
         int clear_offset = (flag_value + 2) % 3;
@@ -88,7 +88,7 @@ struct LamportComm
         }
     }
 
-    __device__ __forceinline__ void update(int new_clear_size)
+    __device__ __forceinline__ void update(int64_t new_clear_size)
     {
         if (blockIdx.x == 0 && threadIdx.x == 0)
         {
@@ -103,10 +103,10 @@ struct LamportComm
     int* counter_ptr;
     int* flag_ptr;
-    int* clear_ptr;
+    int64_t* clear_ptr;
     uint8_t* data_bufs[NRanks];
     uint8_t* clear_buf;
-    int clear_size;
+    int64_t clear_size;
     int flag_value;
 };
 
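A note on the indexing in the hunks above: the kernels receive a single `void**` workspace in which slots `[0, NRanks * 3)` hold the per-rank pointers for the three IPC regions, slot `NRanks * 3` is the `int` flag buffer, and, after this patch, slot `NRanks * 3 + 1` is the new `int64_t` layout buffer. Below is a minimal host-side sketch of that table using the same index arithmetic as the patched `LamportComm` constructor; the values and the standalone `main` are illustrative, not part of the patch.

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

int main()
{
    constexpr int NRanks = 4;
    int flag_buffer[3] = {0, 0, 0};          // [0] read counter, [1] non-lamport flag, [2] lamport flag
    int64_t layout_buffer[2] = {0, 1 << 20}; // [0] clear size, [1] triple-buffer offset (hypothetical)
    std::vector<void*> workspace(NRanks * 3 + 2, nullptr);
    // Slots [0, NRanks * 3) would hold per-rank pointers for the three IPC regions.
    workspace[NRanks * 3] = flag_buffer;
    workspace[NRanks * 3 + 1] = layout_buffer;

    // Same index arithmetic as the patched LamportComm constructor:
    int* counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
    int* flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
    int64_t* clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
    int64_t comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
    std::printf("counter=%d flag=%d clear=%lld comm=%lld\n", *counter_ptr, *flag_ptr,
        static_cast<long long>(*clear_ptr), static_cast<long long>(comm_size));
    return 0;
}
```

The design point the patch makes is visible here: the two 64-bit fields move out of the `int` flag buffer into their own slot, so the flag buffer stays a plain array of three 32-bit flags.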
diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu
index 3c4b4b50499..a594fb5933d 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu
@@ -21,18 +21,18 @@ TRTLLM_NAMESPACE_BEGIN
 namespace kernels::ar_fusion
 {
-__global__ void lamport_initialize_kernel(float* ptr, int size)
+__global__ void lamport_initialize_kernel(float* ptr, size_t size)
 {
-    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
     if (idx >= size)
         return;
     ptr[idx] = -0.f;
 }
 
-void lamport_initialize(void* ptr, int bytes, cudaStream_t stream)
+void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream)
 {
-    int grid_size = (bytes + 127) / 128;
-    lamport_initialize_kernel<<<grid_size, 128, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
+    int grid_size = static_cast<int>((bytes + 1023) / 1024);
+    lamport_initialize_kernel<<<grid_size, 1024, 0, stream>>>(reinterpret_cast<float*>(ptr), bytes / sizeof(float));
 }
 
 Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
@@ -45,10 +45,11 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
     int device_id;
     TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
     m_buffer_mgr = std::make_shared<tensorrt_llm::runtime::BufferManager>(m_cuda_stream);
-    int buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
-    int flag_size = tp_size * kBarrierFlagCount * sizeof(int);
-    int lamport_comm_size = tp_size * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
-    int lamport_buffer_size = 3 * lamport_comm_size;
+    size_t buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
+    size_t flag_size = tp_size * kBarrierFlagCount * sizeof(int);
+    size_t lamport_comm_size
+        = static_cast<size_t>(tp_size) * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
+    size_t lamport_buffer_size = 3 * lamport_comm_size;
     for (auto size : {buffer_size, flag_size, lamport_buffer_size})
     {
         m_ipc_mem_handles.emplace_back(size, *m_buffer_mgr, m_world_config, p2p_supported);
@@ -61,20 +62,20 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
             workspace.push_back(ipc_mem_handle.getCommPtrs()[r]);
         }
     }
-    // atomic flag read counter
-    // kernel_flag_ptr[0] = 0;
-    // non-lamport flag
-    // kernel_flag_ptr[1] = 0;
-    // lamport flag
-    // kernel_flag_ptr[2] = 0;
-    // lamport triple buffer offset
-    // kernel_flag_ptr[3] = lamport_comm_size;
-    // lamport clear size
-    // kernel_flag_ptr[4] = 0;
-    TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 5 * sizeof(int)));
-    std::vector<int> h_data{0, 0, 0, lamport_comm_size, 0};
-    TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_data.data(), 5 * sizeof(int), cudaMemcpyHostToDevice));
+    // flag_buffer[0], atomic flag read counter
+    // flag_buffer[1], non-lamport flag
+    // flag_buffer[2], lamport flag
+    TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 3 * sizeof(int)));
+    std::vector<int> h_flag_data{0, 0, 0};
+    TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_flag_data.data(), 3 * sizeof(int), cudaMemcpyHostToDevice));
     workspace.push_back(m_flag_d_ptr);
+    // layout_buffer[0], clear size for next lamport kernel
+    // layout_buffer[1], triple buffer offset for lamport kernel
+    TLLM_CUDA_CHECK(cudaMalloc(&m_layout_d_ptr, 2 * sizeof(int64_t)));
+    std::vector<int64_t> h_layout_data{0, static_cast<int64_t>(lamport_comm_size)};
+    TLLM_CUDA_CHECK(cudaMemcpy(m_layout_d_ptr, h_layout_data.data(), 2 * sizeof(int64_t), cudaMemcpyHostToDevice));
+    workspace.push_back(m_layout_d_ptr);
+
     TLLM_CUDA_CHECK(cudaMalloc(&m_workspace, workspace.size() * sizeof(void*)));
     TLLM_CUDA_CHECK(
         cudaMemcpy(m_workspace, workspace.data(), workspace.size() * sizeof(void*), cudaMemcpyHostToDevice));
@@ -87,6 +88,10 @@ Workspace::~Workspace()
     {
         TLLM_CUDA_CHECK(cudaFree(m_flag_d_ptr));
     }
+    if (m_layout_d_ptr)
+    {
+        TLLM_CUDA_CHECK(cudaFree(m_layout_d_ptr));
+    }
     if (m_workspace)
     {
         TLLM_CUDA_CHECK(cudaFree(m_workspace));
diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h
index 055d29c3a05..4bf31f15aa4 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.h
@@ -41,9 +41,10 @@ class Workspace
     void* m_workspace;
     std::shared_ptr<tensorrt_llm::runtime::CudaStream> m_cuda_stream;
    void* m_flag_d_ptr;
+    void* m_layout_d_ptr;
 };
 
-void lamport_initialize(void* ptr, int bytes, cudaStream_t stream);
+void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream);
 
 } // namespace kernels::ar_fusion
 TRTLLM_NAMESPACE_END
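Why the widening matters: the lamport region is `3 * tp_size * max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half)`, which exceeds a 32-bit `int` at realistic shapes, so the old `int` clear size and buffer offset could wrap. A quick standalone check of that arithmetic; the shape values below are illustrative, not taken from the patch:

```cpp
#include <climits>
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical shape: tp=8, 8192 tokens, hidden 8192, fp16 (2 bytes per element).
    int tp_size = 8, max_token_num = 8192, hidden_dim = 8192;
    int64_t lamport_comm_size = static_cast<int64_t>(tp_size) * max_token_num * hidden_dim * 2;
    int64_t lamport_buffer_size = 3 * lamport_comm_size; // triple buffering
    std::printf("comm = %lld bytes, total = %lld bytes, INT_MAX = %d\n",
        static_cast<long long>(lamport_comm_size),
        static_cast<long long>(lamport_buffer_size), INT_MAX);
    // Prints: comm = 1073741824 bytes, total = 3221225472 bytes, INT_MAX = 2147483647,
    // i.e. the total (and any clear size derived from it) no longer fits in int.
    return 0;
}
```

This is also why the host-side sizes in `Workspace::Workspace` move to `size_t` and the first multiplicand is cast before the product is formed: without the cast, the multiplication would still be evaluated in 32-bit `int` and overflow before the widening assignment.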
diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu
index 44a32f9a1f3..3930c63ec5c 100644
--- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu
+++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu
@@ -31,9 +31,9 @@ struct LamportComm
     {
         counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
         flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
-        clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
+        clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
         flag_value = *flag_ptr;
-        int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
+        auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
         clear_size = *clear_ptr;
         int data_offset = flag_value % 3;
         int clear_offset = (flag_value + 2) % 3;
@@ -49,7 +49,7 @@ struct LamportComm
         }
     }
 
-    __device__ __forceinline__ void update(int new_clear_size)
+    __device__ __forceinline__ void update(int64_t new_clear_size)
     {
         if (blockIdx.x == 0 && threadIdx.x == 0)
         {
@@ -64,10 +64,10 @@ struct LamportComm
     int* counter_ptr;
     int* flag_ptr;
-    int* clear_ptr;
+    int64_t* clear_ptr;
     uint8_t* data_bufs[NRanks];
     uint8_t* clear_buf;
-    int clear_size;
+    int64_t clear_size;
     int flag_value;
 };
 
diff --git a/tensorrt_llm/plugin/plugin.py b/tensorrt_llm/plugin/plugin.py
index 60e12e98207..154194cab14 100644
--- a/tensorrt_llm/plugin/plugin.py
+++ b/tensorrt_llm/plugin/plugin.py
@@ -737,14 +737,25 @@ def allocate_allreduce_fusion_workspace(
         lamport_buffers.local_ptr,
         3 * lamport_buffers_size,
     )
-    flag_buffer = torch.tensor([0, 0, 0, lamport_buffers_size, 0],
-                               dtype=torch.int,
-                               device="cuda")
-    buffers = [ipc_buffers, ipc_barriers, lamport_buffers, flag_buffer]
+    # flag_buffer[0], atomic flag read counter
+    # flag_buffer[1], non-lamport flag
+    # flag_buffer[2], lamport flag
+    flag_buffer = torch.tensor([0, 0, 0], dtype=torch.int, device="cuda")
+    # layout_buffer[0], clear size for next lamport kernel
+    # layout_buffer[1], triple buffer offset for lamport kernel
+    layout_buffer = torch.tensor([0, lamport_buffers_size],
+                                 dtype=torch.int64,
+                                 device="cuda")
+
+    buffers = [
+        ipc_buffers, ipc_barriers, lamport_buffers, flag_buffer,
+        layout_buffer
+    ]
     return buffers, torch.tensor(
         ipc_buffers.serialize() + ipc_barriers.serialize() +
-        lamport_buffers.serialize() + [flag_buffer.data_ptr()],
+        lamport_buffers.serialize() + [flag_buffer.data_ptr()] +
+        [layout_buffer.data_ptr()],
         dtype=torch.int64,
         device="cuda")
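End to end, the Python side serializes pointers in the order the kernels index them: the three IPC regions contribute one pointer per rank each (matching the C++ `Workspace`, which pushes `getCommPtrs()[r]` for every handle), then `flag_buffer` lands in slot `NRanks * 3` and `layout_buffer` in slot `NRanks * 3 + 1`. The triple-buffer rotation that consumes `layout_buffer[1]` is sketched below; the `comm_size` value is hypothetical, chosen only to show where 32-bit byte offsets would wrap:

```cpp
#include <cstdint>
#include <cstdio>

int main()
{
    // Hypothetical per-buffer region size, deliberately > INT_MAX.
    int64_t comm_size = 3LL << 30;
    for (int flag_value = 0; flag_value < 4; ++flag_value)
    {
        int data_offset = flag_value % 3;        // buffer written this iteration
        int clear_offset = (flag_value + 2) % 3; // buffer scrubbed for reuse
        // int * int64_t promotes to int64_t, so these byte offsets cannot wrap.
        std::printf("flag=%d data_off=%lld clear_off=%lld\n", flag_value,
            static_cast<long long>(data_offset * comm_size),
            static_cast<long long>(clear_offset * comm_size));
    }
    return 0;
}
```

With `comm_size` read as `int64_t` from the layout buffer, the `base + offset * comm_size` addressing in the kernels stays in 64-bit arithmetic for all three rotating buffers.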