Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,9 @@ struct LamportComm
{
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
flag_value = *flag_ptr;
int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
clear_size = *clear_ptr;
int data_offset = flag_value % 3;
int clear_offset = (flag_value + 2) % 3;
Expand All @@ -88,7 +88,7 @@ struct LamportComm
}
}

__device__ __forceinline__ void update(int new_clear_size)
__device__ __forceinline__ void update(int64_t new_clear_size)
{
if (blockIdx.x == 0 && threadIdx.x == 0)
{
Expand All @@ -103,10 +103,10 @@ struct LamportComm

int* counter_ptr;
int* flag_ptr;
int* clear_ptr;
int64_t* clear_ptr;
uint8_t* data_bufs[NRanks];
uint8_t* clear_buf;
int clear_size;
int64_t clear_size;
int flag_value;
};

Expand Down
49 changes: 27 additions & 22 deletions cpp/tensorrt_llm/kernels/communicationKernels/allReduceWorkspace.cu
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,18 @@ TRTLLM_NAMESPACE_BEGIN
namespace kernels::ar_fusion
{

// Fills the first `size` floats at `ptr` with negative zero (-0.f).
//
// ptr  : device pointer to the float buffer to fill.
// size : number of float elements (not bytes) to write.
//
// Launch layout: 1-D grid of 1-D blocks. Any grid/block size is valid because
// each thread strides over the buffer, so the result does not depend on the
// launch covering every element with its own thread.
// NOTE(review): -0.f is presumably the "not yet written" marker that the
// Lamport-style communication kernels poll for — confirm against the consumers.
__global__ void lamport_initialize_kernel(float* ptr, size_t size)
{
    // Promote to size_t before multiplying so buffers with more than 2^31
    // elements do not overflow 32-bit index arithmetic.
    size_t const stride = static_cast<size_t>(gridDim.x) * blockDim.x;
    for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; idx < size; idx += stride)
    {
        ptr[idx] = -0.f;
    }
}

// Asynchronously fills a device buffer with -0.f floats on `stream`.
//
// ptr    : device pointer to the buffer; it is reinterpreted as float*.
// bytes  : buffer size in bytes. Only whole floats are written, so any
//          trailing `bytes % sizeof(float)` bytes are left untouched.
// stream : CUDA stream the fill kernel is enqueued on; the call does not
//          synchronize.
void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream)
{
    constexpr int kBlockSize = 1024;
    size_t const num_floats = bytes / sizeof(float);
    if (num_floats == 0)
    {
        // Nothing to write. Launching with a zero-block grid would be rejected
        // as an invalid launch configuration, so return early instead.
        return;
    }
    // One thread per float: size the grid from the element count, not the byte count.
    int const grid_size = static_cast<int>((num_floats + kBlockSize - 1) / kBlockSize);
    lamport_initialize_kernel<<<grid_size, kBlockSize, 0, stream>>>(reinterpret_cast<float*>(ptr), num_floats);
}

Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
Expand All @@ -45,10 +45,11 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
int device_id;
TLLM_CUDA_CHECK(cudaGetDevice(&device_id));
m_buffer_mgr = std::make_shared<tensorrt_llm::runtime::BufferManager>(m_cuda_stream);
int buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
int flag_size = tp_size * kBarrierFlagCount * sizeof(int);
int lamport_comm_size = tp_size * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
int lamport_buffer_size = 3 * lamport_comm_size;
size_t buffer_size = tp_size * max_token_num * hidden_dim * sizeof(half);
size_t flag_size = tp_size * kBarrierFlagCount * sizeof(int);
size_t lamport_comm_size
= static_cast<size_t>(tp_size) * std::max(kOneShotMaxToken, max_token_num) * hidden_dim * sizeof(half);
size_t lamport_buffer_size = 3 * lamport_comm_size;
for (auto size : {buffer_size, flag_size, lamport_buffer_size})
{
m_ipc_mem_handles.emplace_back(size, *m_buffer_mgr, m_world_config, p2p_supported);
Expand All @@ -61,20 +62,20 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
workspace.push_back(ipc_mem_handle.getCommPtrs()[r]);
}
}
// atomic flag read counter
// kernel_flag_ptr[0] = 0;
// non-lamport flag
// kernel_flag_ptr[1] = 0;
// lamport flag
// kernel_flag_ptr[2] = 0;
// lamport triple buffer offset
// kernel_flag_ptr[3] = lamport_comm_size;
// lamport clear size
// kernel_flag_ptr[4] = 0;
TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 5 * sizeof(int)));
std::vector<int> h_data{0, 0, 0, lamport_comm_size, 0};
TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_data.data(), 5 * sizeof(int), cudaMemcpyHostToDevice));
// flag_buffer[0], atomic flag read counter
// flag_buffer[1], non-lamport flag
// flag_buffer[2], lamport flag
TLLM_CUDA_CHECK(cudaMalloc(&m_flag_d_ptr, 3 * sizeof(int)));
std::vector<int> h_flag_data{0, 0, 0};
TLLM_CUDA_CHECK(cudaMemcpy(m_flag_d_ptr, h_flag_data.data(), 3 * sizeof(int), cudaMemcpyHostToDevice));
workspace.push_back(m_flag_d_ptr);
// layout_buffer[0], clear size for next lamport kernel
// layout_buffer[1], triple buffer offset for lamport kernel
TLLM_CUDA_CHECK(cudaMalloc(&m_layout_d_ptr, 2 * sizeof(int64_t)));
std::vector<int64_t> h_layout_data{0, static_cast<int64_t>(lamport_comm_size)};
TLLM_CUDA_CHECK(cudaMemcpy(m_layout_d_ptr, h_layout_data.data(), 2 * sizeof(int64_t), cudaMemcpyHostToDevice));
workspace.push_back(m_layout_d_ptr);

TLLM_CUDA_CHECK(cudaMalloc(&m_workspace, workspace.size() * sizeof(void*)));
TLLM_CUDA_CHECK(
cudaMemcpy(m_workspace, workspace.data(), workspace.size() * sizeof(void*), cudaMemcpyHostToDevice));
Expand All @@ -87,6 +88,10 @@ Workspace::~Workspace()
{
TLLM_CUDA_CHECK(cudaFree(m_flag_d_ptr));
}
if (m_layout_d_ptr)
{
TLLM_CUDA_CHECK(cudaFree(m_layout_d_ptr));
}
if (m_workspace)
{
TLLM_CUDA_CHECK(cudaFree(m_workspace));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,10 @@ class Workspace
void* m_workspace;
std::shared_ptr<tensorrt_llm::runtime::CudaStream> m_cuda_stream;
void* m_flag_d_ptr;
void* m_layout_d_ptr;
};

void lamport_initialize(void* ptr, int bytes, cudaStream_t stream);
void lamport_initialize(void* ptr, size_t bytes, cudaStream_t stream);
} // namespace kernels::ar_fusion

TRTLLM_NAMESPACE_END
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ struct LamportComm
{
counter_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[0];
flag_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[2];
clear_ptr = &reinterpret_cast<int*>(workspace[NRanks * 3])[4];
clear_ptr = &reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[0];
flag_value = *flag_ptr;
int comm_size = reinterpret_cast<int*>(workspace[NRanks * 3])[3];
auto comm_size = reinterpret_cast<int64_t*>(workspace[NRanks * 3 + 1])[1];
clear_size = *clear_ptr;
int data_offset = flag_value % 3;
int clear_offset = (flag_value + 2) % 3;
Expand All @@ -49,7 +49,7 @@ struct LamportComm
}
}

__device__ __forceinline__ void update(int new_clear_size)
__device__ __forceinline__ void update(int64_t new_clear_size)
{
if (blockIdx.x == 0 && threadIdx.x == 0)
{
Expand All @@ -64,10 +64,10 @@ struct LamportComm

int* counter_ptr;
int* flag_ptr;
int* clear_ptr;
int64_t* clear_ptr;
uint8_t* data_bufs[NRanks];
uint8_t* clear_buf;
int clear_size;
int64_t clear_size;
int flag_value;
};

Expand Down
21 changes: 16 additions & 5 deletions tensorrt_llm/plugin/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,14 +737,25 @@ def allocate_allreduce_fusion_workspace(
lamport_buffers.local_ptr,
3 * lamport_buffers_size,
)
flag_buffer = torch.tensor([0, 0, 0, lamport_buffers_size, 0],
dtype=torch.int,
device="cuda")
buffers = [ipc_buffers, ipc_barriers, lamport_buffers, flag_buffer]
# flag_buffer[0], atomic flag read counter
# flag_buffer[1], non-lamport flag
# flag_buffer[2], lamport flag
flag_buffer = torch.tensor([0, 0, 0], dtype=torch.int, device="cuda")
# layout_buffer[0], clear size for next lamport kernel
# layout_buffer[1], triple buffer offset for lamport kernel
layout_buffer = torch.tensor([0, lamport_buffers_size],
dtype=torch.int64,
device="cuda")

buffers = [
ipc_buffers, ipc_barriers, lamport_buffers, flag_buffer,
layout_buffer
]

return buffers, torch.tensor(
ipc_buffers.serialize() + ipc_barriers.serialize() +
lamport_buffers.serialize() + [flag_buffer.data_ptr()],
lamport_buffers.serialize() + [flag_buffer.data_ptr()] +
[layout_buffer.data_ptr()],
dtype=torch.int64,
device="cuda")

Expand Down