Skip to content

Commit 8564988

Browse files
committed
rename fields
1 parent e7cf5b2 commit 8564988

File tree

2 files changed

+54
-54
lines changed

2 files changed

+54
-54
lines changed

cpp/tensorrt_llm/thop/alltoallOp.cpp

Lines changed: 48 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -115,102 +115,102 @@ std::vector<torch::Tensor> alltoall_helix(
115115
/**
116116
* Helix All-to-All operation with two fields.
117117
*
118-
* Input tensors have shape [..., cp_size, kv_lora_rank] for field0 and [...,
119-
* cp_size, 2] for field1. The operation exchanges data along the cp_size
118+
* Input tensors have shape [..., cp_size, kv_lora_rank] for partial_o and [...,
119+
* cp_size, 2] for softmax_stats. The operation exchanges data along the cp_size
120120
* dimension across all ranks.
121121
*
122-
* @param field0 Field 0 tensor (half precision, shape [..., cp_size,
122+
* @param partial_o Field 0 tensor (half or bfloat16, shape [..., cp_size,
123123
* kv_lora_rank])
124-
* @param field1 Field 1 tensor (float32, shape [..., cp_size, 2])
124+
* @param softmax_stats Field 1 tensor (float32, shape [..., cp_size, N] with N even and N >= 2)
125125
* @param workspace Workspace tensor (uint64, strided across ranks)
126126
* @param cp_rank Current context parallel rank
127127
* @param cp_size Total number of context parallel ranks
128-
* @return tuple of (field0_out, field1_out) with same shapes as inputs
128+
* @return tuple of (partial_o_out, softmax_stats_out) with same shapes as inputs
129129
*/
130130
std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
131-
torch::Tensor field0, torch::Tensor field1, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
131+
torch::Tensor partial_o, torch::Tensor softmax_stats, torch::Tensor workspace, int64_t cp_rank, int64_t cp_size)
132132
{
133133

134134
// Input validation
135-
CHECK_TH_CUDA(field0);
136-
CHECK_TH_CUDA(field1);
135+
CHECK_TH_CUDA(partial_o);
136+
CHECK_TH_CUDA(softmax_stats);
137137
CHECK_TH_CUDA(workspace);
138-
CHECK_CONTIGUOUS(field0);
139-
CHECK_CONTIGUOUS(field1);
138+
CHECK_CONTIGUOUS(partial_o);
139+
CHECK_CONTIGUOUS(softmax_stats);
140140

141141
// Type checks
142-
TORCH_CHECK(field0.scalar_type() == at::ScalarType::Half || field0.scalar_type() == at::ScalarType::BFloat16,
143-
"field0 must be half or bfloat16");
144-
CHECK_TYPE(field1, at::ScalarType::Float);
142+
TORCH_CHECK(partial_o.scalar_type() == at::ScalarType::Half || partial_o.scalar_type() == at::ScalarType::BFloat16,
143+
"partial_o must be half or bfloat16");
144+
CHECK_TYPE(softmax_stats, at::ScalarType::Float);
145145
CHECK_TYPE(workspace, at::ScalarType::UInt64);
146146

147147
// Shape validation
148-
TORCH_CHECK(field0.dim() >= 2, "field0 must have at least 2 dimensions");
149-
TORCH_CHECK(field1.dim() >= 2, "field1 must have at least 2 dimensions");
150-
TORCH_CHECK(field0.dim() == field1.dim(), "field0 and field1 must have same number of dimensions");
148+
TORCH_CHECK(partial_o.dim() >= 2, "partial_o must have at least 2 dimensions");
149+
TORCH_CHECK(softmax_stats.dim() >= 2, "softmax_stats must have at least 2 dimensions");
150+
TORCH_CHECK(partial_o.dim() == softmax_stats.dim(), "partial_o and softmax_stats must have same number of dimensions");
151151

152152
// Get dimensions
153-
int kv_lora_rank = field0.size(-1);
154-
TORCH_CHECK(field0.size(-2) == cp_size && field1.size(-2) == cp_size,
155-
"field0/1 second-to-last dimension must equal cp_size");
153+
int kv_lora_rank = partial_o.size(-1);
154+
TORCH_CHECK(partial_o.size(-2) == cp_size && softmax_stats.size(-2) == cp_size,
155+
"partial_o/softmax_stats second-to-last dimension must equal cp_size");
156156
TORCH_CHECK(
157-
field1.size(-1) % 2 == 0 && field1.size(-1) >= 2, "field1 last dimension must be divisible by 2 (float2)");
158-
bool allowVariableField1 = field1.size(-1) > 2;
157+
softmax_stats.size(-1) % 2 == 0 && softmax_stats.size(-1) >= 2, "softmax_stats last dimension must be divisible by 2 (float2)");
158+
bool allowVariableField1 = softmax_stats.size(-1) > 2;
159159

160160
// Check that leading dimensions match
161-
for (int i = 0; i < field0.dim() - 2; i++)
161+
for (int i = 0; i < partial_o.dim() - 2; i++)
162162
{
163163
TORCH_CHECK(
164-
field0.size(i) == field1.size(i), "field0 and field1 must have matching dimensions except last two");
164+
partial_o.size(i) == softmax_stats.size(i), "partial_o and softmax_stats must have matching dimensions except last two");
165165
}
166-
TORCH_CHECK(field0.size(-1) * field0.element_size() % 16 == 0, "field0 must be aligned to 16 bytes");
166+
TORCH_CHECK(partial_o.size(-1) * partial_o.element_size() % 16 == 0, "partial_o must be aligned to 16 bytes");
167167

168168
TORCH_CHECK(workspace.dim() == 2, "workspace must be 2D (strided across ranks)");
169169
TORCH_CHECK(workspace.size(0) == cp_size, "workspace must have cp_size rows");
170170

171171
// Calculate entry count (product of all dimensions before cp_size)
172172
// This is the number of entries to process per peer rank
173173
int entry_count = 1;
174-
for (int i = 0; i < field0.dim() - 2; i++)
174+
for (int i = 0; i < partial_o.dim() - 2; i++)
175175
{
176-
entry_count *= field0.size(i);
176+
entry_count *= partial_o.size(i);
177177
}
178178

179179
// Reshape to 3D: [entry_count, cp_size, feature_dim]
180-
torch::Tensor field0_3d = field0.reshape({entry_count, cp_size, kv_lora_rank});
181-
torch::Tensor field1_3d = field1.reshape({entry_count, cp_size, field1.size(-1)});
180+
torch::Tensor partial_o_3d = partial_o.reshape({entry_count, cp_size, kv_lora_rank});
181+
torch::Tensor softmax_stats_3d = softmax_stats.reshape({entry_count, cp_size, softmax_stats.size(-1)});
182182

183183
// Allocate output tensors (same shape as input)
184-
torch::Tensor field0_out = torch::empty_like(field0);
185-
torch::Tensor field1_out = torch::empty_like(field1);
184+
torch::Tensor partial_o_out = torch::empty_like(partial_o);
185+
torch::Tensor softmax_stats_out = torch::empty_like(softmax_stats);
186186

187-
torch::Tensor field0_out_3d = field0_out.reshape({entry_count, cp_size, kv_lora_rank});
188-
torch::Tensor field1_out_3d = field1_out.reshape({entry_count, cp_size, field1.size(-1)});
187+
torch::Tensor partial_o_out_3d = partial_o_out.reshape({entry_count, cp_size, kv_lora_rank});
188+
torch::Tensor softmax_stats_out_3d = softmax_stats_out.reshape({entry_count, cp_size, softmax_stats.size(-1)});
189189

190190
// Setup parameters
191191
tensorrt_llm::kernels::HelixAllToAllParams params;
192192

193193
// Field 0 (variable size half)
194-
params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(field0_3d.data_ptr());
194+
params.sendFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_3d.data_ptr());
195195
params.sendFields[0].elementCount = kv_lora_rank;
196-
params.sendFields[0].elementSize = field0.element_size();
197-
params.sendFields[0].stride = field0_3d.stride(1) * field0.element_size();
196+
params.sendFields[0].elementSize = partial_o.element_size();
197+
params.sendFields[0].stride = partial_o_3d.stride(1) * partial_o.element_size();
198198

199-
params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(field0_out_3d.data_ptr());
199+
params.recvFields[0].dataPtr = reinterpret_cast<uint8_t*>(partial_o_out_3d.data_ptr());
200200
params.recvFields[0].elementCount = kv_lora_rank;
201-
params.recvFields[0].elementSize = field0.element_size();
202-
params.recvFields[0].stride = field0_out_3d.stride(1) * field0.element_size();
201+
params.recvFields[0].elementSize = partial_o.element_size();
202+
params.recvFields[0].stride = partial_o_out_3d.stride(1) * partial_o.element_size();
203203

204204
// Field 1 (float32; one or more float2 values per entry)
205-
params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(field1_3d.data_ptr<float>());
206-
params.sendFields[1].elementCount = field1.size(-1);
207-
params.sendFields[1].elementSize = field1.element_size();
208-
params.sendFields[1].stride = field1_3d.stride(1) * field1.element_size();
205+
params.sendFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_3d.data_ptr<float>());
206+
params.sendFields[1].elementCount = softmax_stats.size(-1);
207+
params.sendFields[1].elementSize = softmax_stats.element_size();
208+
params.sendFields[1].stride = softmax_stats_3d.stride(1) * softmax_stats.element_size();
209209

210-
params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(field1_out_3d.data_ptr<float>());
211-
params.recvFields[1].elementCount = field1.size(-1);
212-
params.recvFields[1].elementSize = field1.element_size();
213-
params.recvFields[1].stride = field1_out_3d.stride(1) * field1.element_size();
210+
params.recvFields[1].dataPtr = reinterpret_cast<uint8_t*>(softmax_stats_out_3d.data_ptr<float>());
211+
params.recvFields[1].elementCount = softmax_stats.size(-1);
212+
params.recvFields[1].elementSize = softmax_stats.element_size();
213+
params.recvFields[1].stride = softmax_stats_out_3d.stride(1) * softmax_stats.element_size();
214214

215215
// Entry count and workspace
216216
params.entryCount = entry_count;
@@ -227,7 +227,7 @@ std::tuple<torch::Tensor, torch::Tensor> alltoall_helix_native(
227227
auto stream = at::cuda::getCurrentCUDAStream();
228228
tensorrt_llm::kernels::launchHelixAllToAll(params, allowVariableField1, stream);
229229

230-
return std::make_tuple(field0_out, field1_out);
230+
return std::make_tuple(partial_o_out, softmax_stats_out);
231231
}
232232

233233
/**
@@ -267,7 +267,7 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
267267
{
268268
m.def("alltoall_helix(Tensor[] input_list, int[] group, int? num_lists) -> Tensor[]");
269269
m.def(
270-
"alltoall_helix_native(Tensor field0, Tensor field1, Tensor workspace, int "
270+
"alltoall_helix_native(Tensor partial_o, Tensor softmax_stats, Tensor workspace, int "
271271
"cp_rank, int cp_size) -> (Tensor, Tensor)");
272272
m.def("get_helix_workspace_size_per_rank(Tensor __dummy__, int cp_size) -> int");
273273
m.def(

tensorrt_llm/_torch/modules/attention.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1147,24 +1147,24 @@ def _attn_forward_gen(self, attn_backend: AttentionBackend, q: torch.Tensor,
11471147
# partial_o: [num_tokens, num_heads * kv_lora_rank] -> [num_tokens, cp_size, num_heads_tp_cp, kv_lora_rank]
11481148
# softmax_stats: [num_tokens, num_heads, 2] -> [num_tokens, cp_size, num_heads_tp_cp, 2]
11491149

1150-
field0 = partial_o.view(num_tokens, cp_size,
1150+
partial_o = partial_o.view(num_tokens, cp_size,
11511151
self.num_heads_tp_cp,
11521152
kv_lora_rank).transpose(1,
11531153
2).contiguous()
1154-
field1 = softmax_stats.view(num_tokens, cp_size,
1154+
softmax_stats = softmax_stats.view(num_tokens, cp_size,
11551155
self.num_heads_tp_cp,
11561156
2).transpose(1, 2).contiguous()
11571157

11581158
# Call FIFO-based helixAllToAll.
1159-
field0_out, field1_out = helix.alltoall_native(field0, field1)
1159+
partial_o_out, softmax_stats_out = helix.alltoall_native(partial_o, softmax_stats)
11601160

1161-
# field0_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
1162-
# field1_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
1161+
# partial_o_out: [num_tokens, num_heads_tp_cp, cp_size, kv_lora_rank]
1162+
# softmax_stats_out: [num_tokens, num_heads_tp_cp, cp_size, 2]
11631163
# cp_dim = 2 (the dimension where cp_size is located)
11641164

11651165
# Call helix_post_process_native with cp_dim=2.
11661166
return torch.ops.trtllm.helix_post_process_native(
1167-
field0_out, field1_out, 1.0, 2)
1167+
partial_o_out, softmax_stats_out, 1.0, 2)
11681168
else:
11691169
attn_output = attn_backend.forward(q, k, v, attn_metadata, **kwargs)
11701170
return attn_output

0 commit comments

Comments
 (0)