@@ -21,18 +21,18 @@ TRTLLM_NAMESPACE_BEGIN
2121namespace kernels ::ar_fusion
2222{
2323
24- __global__ void lamport_initialize_kernel (float * ptr, int size)
24+ __global__ void lamport_initialize_kernel (float * ptr, size_t size)
2525{
26- int idx = blockIdx .x * blockDim .x + threadIdx .x ;
26+ size_t idx = static_cast < size_t >( blockIdx .x ) * blockDim .x + threadIdx .x ;
2727 if (idx >= size)
2828 return ;
2929 ptr[idx] = -0 .f ;
3030}
3131
32- void lamport_initialize (void * ptr, int bytes, cudaStream_t stream)
32+ void lamport_initialize (void * ptr, size_t bytes, cudaStream_t stream)
3333{
34- int grid_size = ( bytes + 127 ) / 128 ;
35- lamport_initialize_kernel<<<grid_size, 128 , 0 , stream>>> (reinterpret_cast <float *>(ptr), bytes / sizeof (float ));
34+ int grid_size = static_cast < int >(( bytes + 1023 ) / 1024 ) ;
35+ lamport_initialize_kernel<<<grid_size, 1024 , 0 , stream>>> (reinterpret_cast <float *>(ptr), bytes / sizeof (float ));
3636}
3737
3838Workspace::Workspace (int rank, int tp_size, int max_token_num, int hidden_dim,
@@ -45,10 +45,11 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
4545 int device_id;
4646 TLLM_CUDA_CHECK (cudaGetDevice (&device_id));
4747 m_buffer_mgr = std::make_shared<tensorrt_llm::runtime::BufferManager>(m_cuda_stream);
48- int buffer_size = tp_size * max_token_num * hidden_dim * sizeof (half);
49- int flag_size = tp_size * kBarrierFlagCount * sizeof (int );
50- int lamport_comm_size = tp_size * std::max (kOneShotMaxToken , max_token_num) * hidden_dim * sizeof (half);
51- int lamport_buffer_size = 3 * lamport_comm_size;
48+ size_t buffer_size = tp_size * max_token_num * hidden_dim * sizeof (half);
49+ size_t flag_size = tp_size * kBarrierFlagCount * sizeof (int );
50+ size_t lamport_comm_size
51+ = static_cast <size_t >(tp_size) * std::max (kOneShotMaxToken , max_token_num) * hidden_dim * sizeof (half);
52+ size_t lamport_buffer_size = 3 * lamport_comm_size;
5253 for (auto size : {buffer_size, flag_size, lamport_buffer_size})
5354 {
5455 m_ipc_mem_handles.emplace_back (size, *m_buffer_mgr, m_world_config, p2p_supported);
@@ -61,20 +62,20 @@ Workspace::Workspace(int rank, int tp_size, int max_token_num, int hidden_dim,
6162 workspace.push_back (ipc_mem_handle.getCommPtrs ()[r]);
6263 }
6364 }
64- // atomic flag read counter
65- // kernel_flag_ptr[0] = 0;
66- // non-lamport flag
67- // kernel_flag_ptr[1] = 0;
68- // lamport flag
69- // kernel_flag_ptr[2] = 0;
70- // lamport triple buffer offset
71- // kernel_flag_ptr[3] = lamport_comm_size;
72- // lamport clear size
73- // kernel_flag_ptr[4] = 0;
74- TLLM_CUDA_CHECK (cudaMalloc (&m_flag_d_ptr, 5 * sizeof (int )));
75- std::vector<int > h_data{0 , 0 , 0 , lamport_comm_size, 0 };
76- TLLM_CUDA_CHECK (cudaMemcpy (m_flag_d_ptr, h_data.data (), 5 * sizeof (int ), cudaMemcpyHostToDevice));
65+ // flag_buffer[0], atomic flag read counter
66+ // flag_buffer[1], non-lamport flag
67+ // flag_buffer[2], lamport flag
68+ TLLM_CUDA_CHECK (cudaMalloc (&m_flag_d_ptr, 3 * sizeof (int )));
69+ std::vector<int > h_flag_data{0 , 0 , 0 };
70+ TLLM_CUDA_CHECK (cudaMemcpy (m_flag_d_ptr, h_flag_data.data (), 3 * sizeof (int ), cudaMemcpyHostToDevice));
7771 workspace.push_back (m_flag_d_ptr);
72+ // layout_buffer[0], clear size for next lamport kernel
73+ // layout_buffer[1], triple buffer offset for lamport kernel
74+ TLLM_CUDA_CHECK (cudaMalloc (&m_layout_d_ptr, 2 * sizeof (int64_t )));
75+ std::vector<int64_t > h_layout_data{0 , static_cast <int64_t >(lamport_comm_size)};
76+ TLLM_CUDA_CHECK (cudaMemcpy (m_layout_d_ptr, h_layout_data.data (), 2 * sizeof (int64_t ), cudaMemcpyHostToDevice));
77+ workspace.push_back (m_layout_d_ptr);
78+
7879 TLLM_CUDA_CHECK (cudaMalloc (&m_workspace, workspace.size () * sizeof (void *)));
7980 TLLM_CUDA_CHECK (
8081 cudaMemcpy (m_workspace, workspace.data (), workspace.size () * sizeof (void *), cudaMemcpyHostToDevice));
@@ -87,6 +88,10 @@ Workspace::~Workspace()
8788 {
8889 TLLM_CUDA_CHECK (cudaFree (m_flag_d_ptr));
8990 }
91+ if (m_layout_d_ptr)
92+ {
93+ TLLM_CUDA_CHECK (cudaFree (m_layout_d_ptr));
94+ }
9095 if (m_workspace)
9196 {
9297 TLLM_CUDA_CHECK (cudaFree (m_workspace));
0 commit comments