@@ -115,102 +115,102 @@ std::vector<torch::Tensor> alltoall_helix(
 /**
  * Helix All-to-All operation with two fields.
  *
- * Input tensors have shape [..., cp_size, kv_lora_rank] for field0 and [...,
- * cp_size, 2] for field1. The operation exchanges data along the cp_size
+ * Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and [...,
+ * cp_size, 2] for softmax_stats. The operation exchanges data along the cp_size
  * dimension across all ranks.
  *
- * @param field0 Field 0 tensor (half precision, shape [..., cp_size,
+ * @param partial_o Field 0 tensor (half precision, shape [..., cp_size,
  * kv_lora_rank])
- * @param field1 Field 1 tensor (float32, shape [..., cp_size, 2])
+ * @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, 2])
  * @param workspace Workspace tensor (uint64, strided across ranks)
  * @param cp_rank Current context parallel rank
  * @param cp_size Total number of context parallel ranks
- * @return tuple of (field0_out, field1_out) with same shapes as inputs
+ * @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs
  */
 std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
-    torch::Tensor field0, torch::Tensor field1, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
+    torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
 {
 
     // Input validation
-    CHECK_TH_CUDA(field0);
-    CHECK_TH_CUDA(field1);
+    CHECK_TH_CUDA(partial_o);
+    CHECK_TH_CUDA(softmax_stats);
     CHECK_TH_CUDA(workspace);
-    CHECK_CONTIGUOUS(field0);
-    CHECK_CONTIGUOUS(field1);
+    CHECK_CONTIGUOUS(partial_o);
+    CHECK_CONTIGUOUS(softmax_stats);
 
     // Type checks
-    TORCH_CHECK(field0.scalar_type() == at::ScalarType::Half || field0.scalar_type() == at::ScalarType::BFloat16,
-        "field0 must be half or bfloat16");
-    CHECK_TYPE(field1, at::ScalarType::Float);
+    TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16,
+        "partial_o must be half or bfloat16");
+    CHECK_TYPE(softmax_stats, at::ScalarType::Float);
     CHECK_TYPE(workspace, at::ScalarType::UInt64);
 
     // Shape validation
-    TORCH_CHECK(field0.dim() >= 2, "field0 must have at least 2 dimensions");
-    TORCH_CHECK(field1.dim() >= 2, "field1 must have at least 2 dimensions");
-    TORCH_CHECK(field0.dim() == field1.dim(), "field0 and field1 must have same number of dimensions");
+    TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions");
+    TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions");
+    TORCH_CHECK(partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions");
 
     // Get dimensions
-    int kv_lora_rank = field0.size(-1);
-    TORCH_CHECK(field0.size(-2) == cp_size && field1.size(-2) == cp_size,
-        "field0/1 second-to-last dimension must equal cp_size");
+    int kv_lora_rank = partial_o.size(-1);
+    TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size,
+        "partial_o/softmax_stats second-to-last dimension must equal cp_size");
     TORCH_CHECK(
-        field1.size(-1) % 2 == 0 && field1.size(-1) >= 2, "field1 last dimension must be divisible by 2 (float2)");
-    bool allowVariableField1 = field1.size(-1) > 2;
+        softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2, "softmax_stats last dimension must be divisible by 2 (float2)");
+    bool allowVariableField1 = softmax_stats.size(-1) > 2;
 
     // Check that leading dimensions match
-    for (int i = 0; i < field0.dim() - 2; i++)
+    for (int i = 0; i < partial_o.dim() - 2; i++)
     {
         TORCH_CHECK(
-            field0.size(i) == field1.size(i), "field0 and field1 must have matching dimensions except last two");
+            partial_o.size(i) == softmax_stats.size(i), "partial_o and softmax_stats must have matching dimensions except last two");
     }
-    TORCH_CHECK(field0.size(-1) * field0.element_size() % 16 == 0, "field0 must be aligned to 16 bytes");
+    TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes");
 
     TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)");
     TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
 
     // Calculate entry count (product of all dimensions before cp_size)
     // This is the number of entries to process per peer rank
     int entry_count = 1;
-    for (int i = 0; i < field0.dim() - 2; i++)
+    for (int i = 0; i < partial_o.dim() - 2; i++)
     {
-        entry_count *= field0.size(i);
+        entry_count *= partial_o.size(i);
     }
 
     // Reshape to 3D: [entry_count, cp_size, feature_dim]
-    torch::Tensor field0_3d = field0.reshape({entry_count, cp_size, kv_lora_rank});
-    torch::Tensor field1_3d = field1.reshape({entry_count, cp_size, field1.size(-1)});
+    torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank});
+    torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)});
 
     // Allocate output tensors (same shape as input)
-    torch::Tensor field0_out = torch::empty_like(field0);
-    torch::Tensor field1_out = torch::empty_like(field1);
+    torch::Tensor partial_o_out = torch::empty_like(partial_o);
+    torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats);
 
-    torch::Tensor field0_out_3d = field0_out.reshape({entry_count, cp_size, kv_lora_rank});
-    torch::Tensor field1_out_3d = field1_out.reshape({entry_count, cp_size, field1.size(-1)});
+    torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank});
+    torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)});
 
     // Setup parameters
     tensorrt_llm::kernels::HelixAllToAllParams params;
 
     // Field 0 (variable size half)
-    params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(field0_3d.data_ptr());
+    params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr());
     params.sendFields[0].elementCount = kv_lora_rank;
-    params.sendFields[0].elementSize = field0.element_size();
-    params.sendFields[0].stride = field0_3d.stride(1) * field0.element_size();
+    params.sendFields[0].elementSize = partial_o.element_size();
+    params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size();
 
-    params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(field0_out_3d.data_ptr());
+    params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr());
     params.recvFields[0].elementCount = kv_lora_rank;
-    params.recvFields[0].elementSize = field0.element_size();
-    params.recvFields[0].stride = field0_out_3d.stride(1) * field0.element_size();
+    params.recvFields[0].elementSize = partial_o.element_size();
+    params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size();
 
     // Field 1 (single float2)
-    params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(field1_3d.data_ptr<float>());
-    params.sendFields[1].elementCount = field1.size(-1);
-    params.sendFields[1].elementSize = field1.element_size();
-    params.sendFields[1].stride = field1_3d.stride(1) * field1.element_size();
+    params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>());
+    params.sendFields[1].elementCount = softmax_stats.size(-1);
+    params.sendFields[1].elementSize = softmax_stats.element_size();
+    params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size();
 
-    params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(field1_out_3d.data_ptr<float>());
-    params.recvFields[1].elementCount = field1.size(-1);
-    params.recvFields[1].elementSize = field1.element_size();
-    params.recvFields[1].stride = field1_out_3d.stride(1) * field1.element_size();
+    params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>());
+    params.recvFields[1].elementCount = softmax_stats.size(-1);
+    params.recvFields[1].elementSize = softmax_stats.element_size();
+    params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size();
 
     // Entry count and workspace
     params.entryCount = entry_count;
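For reference, a minimal caller sketch for the renamed entry point. The `batch` dimension, the `empty` initialization, and the `workspace` variable are illustrative assumptions; the dtype and layout requirements come from the checks above:

    // Per-rank partial attention outputs: half/bf16, contiguous,
    // last two dims [cp_size, kv_lora_rank].
    auto partial_o = torch::empty({batch, cp_size, kv_lora_rank},
        torch::dtype(torch::kHalf).device(torch::kCUDA));
    // Per-rank softmax statistics: float32, one float2 per entry.
    auto softmax_stats = torch::empty({batch, cp_size, 2},
        torch::dtype(torch::kFloat).device(torch::kCUDA));
    // workspace: a [cp_size, n] uint64 tensor, strided across ranks.
    auto [partial_o_out, softmax_stats_out]
        = alltoall_helix_native(partial_o, softmax_stats, workspace, cp_rank, cp_size);

The outputs keep the input shapes; only the data along the cp_size dimension has been exchanged across ranks.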
@@ -227,7 +227,7 @@ std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
     auto stream = at::cuda::getCurrentCUDAStream();
     tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream);
 
-    return std::make_tuple(field0_out, field1_out);
+    return std::make_tuple(partial_o_out, softmax_stats_out);
 }
 
 /**
@@ -267,7 +267,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
 {
     m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
     m.def(
-        "alltoall_helix_native(Tensor field0, Tensor field1, Tensor workspace, int "
+        "alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor workspace, int "
         "cp_rank, int cp_size) -> (Tensor, Tensor)");
     m.def("get_helix_workspace_size_per_rank(Tensor __dummy__, int cp_size) -> int");
     m.def(
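The fragment also registers get_helix_workspace_size_per_rank, which callers would use to size the uint64 workspace. A sketch, assuming the C++ helper mirrors the registered schema and returns a per-rank element count (whether it counts uint64 words or bytes should be confirmed against the kernel header; real usage presumably fills the rows with peer-mapped buffer addresses rather than zeros):

    // `dummy` only supplies device context, per the `__dummy__` schema argument.
    int64_t per_rank = get_helix_workspace_size_per_rank(dummy, cp_size);
    // One row per rank, matching the "workspace must have cp_size rows" check.
    auto workspace = torch::zeros({cp_size, per_rank},
        torch::TensorOptions().dtype(at::ScalarType::UInt64).device(torch::kCUDA));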