Skip to content

Commit 76795cb

Browse files
authored
[misc] megatron grpo support non-padding-free (#7218)
1 parent 8c4c027 commit 76795cb

File tree

3 files changed

+158
-78
lines changed

3 files changed

+158
-78
lines changed

swift/megatron/trainers/grpo_trainer.py

Lines changed: 102 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -297,30 +297,49 @@ def _get_encoded_batch(rollout_batch):
297297
encoded_list = [template.encode(data, return_length=True) for data in rollout_batch]
298298
encoded_batch = to_device(
299299
template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device)
300-
if 'cu_seq_lens_q' in encoded_batch:
301-
cu_seq_lens_q = encoded_batch['cu_seq_lens_q']
302-
else:
303-
cu_seq_lens_q = get_packed_seq_params(encoded_batch['position_ids'])['cu_seq_lens_q']
304-
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
305300

306301
labels = encoded_batch['labels']
307302
batch_size = len(rollout_batch)
308-
max_seq_len = seq_lengths.max().item()
309-
assert self.template.padding_free
310303

311304
truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch],
312305
dtype=torch.bool,
313306
device=self.device)
314307

315-
# completion_mask in rmpad format [1, total_tokens]
316-
completion_mask_rmpad = (labels != -100).float()
317-
completion_mask, _ = pad_logps_back_to_batch(
318-
logps_rmpad=completion_mask_rmpad,
319-
logits_to_keep=max_seq_len,
320-
batch_size=batch_size,
321-
seq_lengths=seq_lengths,
322-
pad_value=0.0)
323-
completion_mask = completion_mask.bool()
308+
if self.template.padding_free:
309+
# In padding_free mode, labels shape is [1, total_seq_len] (rmpad format)
310+
# Calculate seq_lengths from cu_seq_lens or position_ids
311+
if 'cu_seq_lens_q' in encoded_batch:
312+
cu_seq_lens_q = encoded_batch['cu_seq_lens_q']
313+
else:
314+
cu_seq_lens_q = get_packed_seq_params(encoded_batch['position_ids'])['cu_seq_lens_q']
315+
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
316+
max_seq_len = seq_lengths.max().item()
317+
318+
# completion_mask in rmpad format [1, total_tokens]
319+
completion_mask_rmpad = (labels != -100).float()
320+
completion_mask, _ = pad_logps_back_to_batch(
321+
logps_rmpad=completion_mask_rmpad,
322+
logits_to_keep=max_seq_len,
323+
batch_size=batch_size,
324+
seq_lengths=seq_lengths,
325+
pad_value=0.0)
326+
completion_mask = completion_mask.bool()
327+
else:
328+
# In non-padding_free mode, labels shape is [batch_size, seq_len] (batch format)
329+
# Calculate seq_lengths from attention_mask
330+
attention_mask = encoded_batch.get('attention_mask')
331+
if attention_mask is not None:
332+
# attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
333+
if attention_mask.dim() == 4:
334+
attention_mask = attention_mask[:, 0, 0, :]
335+
seq_lengths = attention_mask.sum(dim=-1).to(torch.int64)
336+
else:
337+
# Fallback: assume full sequence length for each sample
338+
seq_lengths = torch.full((batch_size, ), labels.shape[-1], dtype=torch.int64, device=self.device)
339+
max_seq_len = labels.shape[-1]
340+
341+
# completion_mask is already [batch_size, seq_len] in non-padding_free mode
342+
completion_mask = (labels != -100)
324343

325344
encoded_batch.update({
326345
'completion_mask': completion_mask, # [batch_size, max_seq_len]
@@ -400,10 +419,10 @@ def _get_encoded_batch(rollout_batch):
400419

401420
if self.loss_type in ['cispo', 'dapo']:
402421
# Calculate num_items_in_batch
403-
# Count tokens from all mini_batch_data (this includes gathered data from rollout_group)
404-
total_token_count = sum(batch_data['seq_lengths'].sum().item() if self.template.
405-
padding_free else batch_data['completion_mask'].sum().item()
406-
for batch_data in mini_batch_data)
422+
# Count completion tokens from all mini_batch_data (this includes gathered data from rollout_group)
423+
# Use completion_mask.sum() for both padding_free and non-padding_free modes
424+
# since we want the count of actual completion tokens, not sequence lengths
425+
total_token_count = sum(batch_data['completion_mask'].sum().item() for batch_data in mini_batch_data)
407426

408427
# All-reduce across all ranks
409428
total_token_count_tensor = torch.tensor(total_token_count, dtype=torch.int, device=self.device)
@@ -873,22 +892,34 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
873892
with torch.no_grad(), self.null_ref_context() as ref_models:
874893
assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
875894
ref_model = ref_models[0]
876-
ref_per_token_logps_rmpad = self.model_forward(
895+
ref_per_token_logps_raw = self.model_forward(
877896
ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
878-
ref_per_token_logps, _ = pad_logps_back_to_batch(
879-
logps_rmpad=ref_per_token_logps_rmpad,
880-
logits_to_keep=max_seq_len,
881-
batch_size=batch_size,
882-
seq_lengths=seq_lengths)
897+
if self.template.padding_free:
898+
# In padding_free mode, logps are in rmpad format [1, total_tokens]
899+
# Pad to batch format [batch_size, max_seq_len]
900+
ref_per_token_logps, _ = pad_logps_back_to_batch(
901+
logps_rmpad=ref_per_token_logps_raw,
902+
logits_to_keep=max_seq_len,
903+
batch_size=batch_size,
904+
seq_lengths=seq_lengths)
905+
else:
906+
# In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
907+
ref_per_token_logps = ref_per_token_logps_raw
883908
batch['ref_per_token_logps'] = ref_per_token_logps
884909

885-
old_per_token_logps_rmpad = self.model_forward(
910+
old_per_token_logps_raw = self.model_forward(
886911
self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
887-
old_per_token_logps, _ = pad_logps_back_to_batch(
888-
logps_rmpad=old_per_token_logps_rmpad,
889-
logits_to_keep=max_seq_len,
890-
batch_size=batch_size,
891-
seq_lengths=seq_lengths)
912+
if self.template.padding_free:
913+
# In padding_free mode, logps are in rmpad format [1, total_tokens]
914+
# Pad to batch format [batch_size, max_seq_len]
915+
old_per_token_logps, _ = pad_logps_back_to_batch(
916+
logps_rmpad=old_per_token_logps_raw,
917+
logits_to_keep=max_seq_len,
918+
batch_size=batch_size,
919+
seq_lengths=seq_lengths)
920+
else:
921+
# In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
922+
old_per_token_logps = old_per_token_logps_raw
892923
batch['old_per_token_logps'] = old_per_token_logps
893924

894925
return batch
@@ -985,7 +1016,16 @@ def build_pretraining_data_loader(*_args, **kwargs):
9851016
def forward_step(self, data_iterator, model):
9861017
# train_batch_size
9871018
# return: output_tensor, loss_func
988-
data = self.get_batch(data_iterator)
1019+
data = next(data_iterator)
1020+
advantages = data.pop('advantages')
1021+
truncated_mask = data.pop('truncated_mask')
1022+
seq_lengths = data.pop('seq_lengths')
1023+
data = self._prepare_batch(data)
1024+
data.update({
1025+
'advantages': advantages,
1026+
'truncated_mask': truncated_mask,
1027+
'seq_lengths': seq_lengths,
1028+
})
9891029
data.pop('loss_scale', None)
9901030
inputs = self._prepare_model_inputs(data)
9911031

@@ -995,29 +1035,36 @@ def forward_step(self, data_iterator, model):
9951035

9961036
@profiling_decorator
9971037
def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
1038+
args = get_args()
9981039
# Get pre-padded data in batch format [batch_size, max_seq_len]
9991040
advantages = data['advantages'] # [batch_size]
10001041
labels = data['labels']
10011042
completion_mask = data['completion_mask'] # [batch_size, max_seq_len]
1002-
packed_seq_params = data['packed_seq_params']
1043+
packed_seq_params = data.get('packed_seq_params')
10031044
truncated_mask = data['truncated_mask'] # [batch_size]
10041045
seq_lengths = data['seq_lengths'] # [batch_size]
10051046
max_seq_len = completion_mask.shape[1]
10061047
micro_batch_size = self.micro_batch_size
10071048

1008-
# Use full sequence lengths directly (get_logps returns full sequences in CP mode)
1009-
lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size
1010-
+ 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size]
1011-
1012-
# get_logps with per_token=True returns rmpad format [1, total_tokens]
1013-
# Pad to batch format [batch_size, max_seq_len]
1014-
per_token_logps_rmpad = self.get_logps(
1015-
output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True)
1016-
per_token_logps, _ = pad_logps_back_to_batch(
1017-
logps_rmpad=per_token_logps_rmpad,
1018-
logits_to_keep=max_seq_len,
1019-
batch_size=micro_batch_size,
1020-
seq_lengths=seq_lengths)
1049+
if args.padding_free:
1050+
# Use full sequence lengths directly (get_logps returns full sequences in CP mode)
1051+
lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size
1052+
+ 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size]
1053+
1054+
# get_logps with per_token=True returns rmpad format [1, total_tokens]
1055+
# Pad to batch format [batch_size, max_seq_len]
1056+
per_token_logps_rmpad = self.get_logps(
1057+
output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True)
1058+
per_token_logps, _ = pad_logps_back_to_batch(
1059+
logps_rmpad=per_token_logps_rmpad,
1060+
logits_to_keep=max_seq_len,
1061+
batch_size=micro_batch_size,
1062+
seq_lengths=seq_lengths)
1063+
else:
1064+
# In non-padding_free mode, get_logps with per_token=True returns [batch_size, seq_len]
1065+
# No need to pad, already in batch format
1066+
lengths = seq_lengths
1067+
per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, micro_batch_size, per_token=True)
10211068

10221069
# Get pre-padded ref/old/rollout logps from data
10231070
ref_per_token_logps = data.get('ref_per_token_logps') # [batch_size, max_seq_len] or None
@@ -1256,13 +1303,19 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
12561303
with self.stimer(bdata=True):
12571304
data = self.get_batch(data_iterator)
12581305
data.pop('loss_scale', None)
1306+
input_ids = data.get('input_ids')
12591307
labels = data.get('labels')
12601308
context = torch.no_grad() if no_grad else nullcontext()
12611309
with context:
12621310
output_tensor = forward_step_helper(model, data)
1263-
packed_seq_params = data['packed_seq_params']
1311+
# packed_seq_params only exists in padding_free mode
1312+
packed_seq_params = data.get('packed_seq_params')
1313+
if packed_seq_params is not None:
1314+
num_samples = packed_seq_params.num_samples
1315+
else:
1316+
num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
12641317
data['logps'] = None if labels is None else self.get_logps(
1265-
output_tensor, labels, data['packed_seq_params'], packed_seq_params.num_samples, per_token=per_token)
1318+
output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)
12661319
return data
12671320

12681321
def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:

swift/megatron/trainers/rlhf_mixin.py

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
8686
Works for both logps (float) and masks (bool/int).
8787
8888
Args:
89-
tensor: [1, packed_len/cp_size] - CP-split tensor (any dtype)
90-
packed_seq_params: PackedSeqParams object
89+
tensor: [1, packed_len/cp_size] in padding_free mode, or [batch_size, seq_len/cp_size] otherwise
90+
packed_seq_params: PackedSeqParams object (None in non-padding_free mode)
9191
num_samples: Number of samples in the batch
9292
9393
Returns:
94-
output_full: [1, packed_len] - Full sequence tensor
94+
output_full: [1, packed_len] in padding_free mode, or [batch_size, seq_len] otherwise
9595
"""
9696
args = get_args()
9797
cp_size = args.context_parallel_size
@@ -102,36 +102,61 @@ def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
102102
torch.distributed.all_gather(output_list, tensor.contiguous(), group=mpu.get_context_parallel_group())
103103
output_list[cp_rank] = tensor
104104

105-
# Reconstruct full sequence
106-
# Shape: [1, packed_len/cp_size] -> [1, packed_len]
107-
cu_seqlens_full = packed_seq_params.cu_seqlens_q
108-
cu_seqlens_cp = cu_seqlens_full // cp_size
105+
if packed_seq_params is not None:
106+
# padding_free mode: [1, packed_len/cp_size] -> [1, packed_len]
107+
cu_seqlens_full = packed_seq_params.cu_seqlens_q
108+
cu_seqlens_cp = cu_seqlens_full // cp_size
109109

110-
# Calculate total packed length
111-
total_packed_len = cu_seqlens_full[num_samples].item()
112-
output_full = tensor.new_zeros(1, total_packed_len)
110+
# Calculate total packed length
111+
total_packed_len = cu_seqlens_full[num_samples].item()
112+
output_full = tensor.new_zeros(1, total_packed_len)
113113

114-
# Reconstruct each sequence
115-
for i in range(num_samples):
116-
start_full = cu_seqlens_full[i].item()
117-
end_full = cu_seqlens_full[i + 1].item()
118-
seq_len = end_full - start_full
114+
# Reconstruct each sequence
115+
for i in range(num_samples):
116+
start_full = cu_seqlens_full[i].item()
117+
end_full = cu_seqlens_full[i + 1].item()
118+
seq_len = end_full - start_full
119+
120+
# Length of each chunk after CP split
121+
chunk_len = seq_len // cp_size
122+
half_chunk = chunk_len // 2
123+
124+
# Concatenate from each CP rank's output (load-balanced split)
125+
for j in range(cp_size):
126+
o = output_list[j][0]
127+
start_cp = cu_seqlens_cp[i].item()
128+
129+
# Get two half chunks (CP's load-balanced split)
130+
o0 = o[start_cp:start_cp + half_chunk]
131+
o1 = o[start_cp + half_chunk:start_cp + chunk_len]
132+
133+
# Place back to full sequence
134+
output_full[0, start_full + j * half_chunk:start_full + (j + 1) * half_chunk] = o0
135+
output_full[0, end_full - (j + 1) * half_chunk:end_full - j * half_chunk] = o1
136+
else:
137+
# non-padding_free mode: [batch_size, seq_len/cp_size] -> [batch_size, seq_len]
138+
# Each CP rank has chunks split with load-balanced pattern (2*cp_size chunks)
139+
batch_size = tensor.shape[0]
140+
seq_len_per_cp = tensor.shape[1]
141+
full_seq_len = seq_len_per_cp * cp_size
119142

120-
# Length of each chunk after CP split
121-
chunk_len = seq_len // cp_size
122-
half_chunk = chunk_len // 2
143+
output_full = tensor.new_zeros(batch_size, full_seq_len)
123144

124-
# Concatenate from each CP rank's output (load-balanced split)
125-
for j in range(cp_size):
126-
o = output_list[j][0]
127-
start_cp = cu_seqlens_cp[i].item()
145+
# Each CP rank j holds chunks j and (2*cp_size - j - 1) from the original 2*cp_size split
146+
# Reconstruct the full sequence by placing chunks back in correct positions
147+
chunk_len = full_seq_len // (2 * cp_size)
128148

129-
# Get two half chunks (CP's load-balanced split)
130-
o0 = o[start_cp:start_cp + half_chunk]
131-
o1 = o[start_cp + half_chunk:start_cp + chunk_len]
132-
133-
# Place back to full sequence
134-
output_full[0, start_full + j * half_chunk:start_full + (j + 1) * half_chunk] = o0
135-
output_full[0, end_full - (j + 1) * half_chunk:end_full - j * half_chunk] = o1
149+
for j in range(cp_size):
150+
o = output_list[j] # [batch_size, seq_len_per_cp]
151+
# This rank holds 2 chunks: chunk j and chunk (2*cp_size - j - 1)
152+
half_len = seq_len_per_cp // 2
153+
o0 = o[:, :half_len] # First half -> chunk j
154+
o1 = o[:, half_len:] # Second half -> chunk (2*cp_size - j - 1)
155+
156+
# Place chunk j at position j * chunk_len
157+
output_full[:, j * chunk_len:(j + 1) * chunk_len] = o0
158+
# Place chunk (2*cp_size - j - 1) at position (2*cp_size - j - 1) * chunk_len
159+
reverse_idx = 2 * cp_size - j - 1
160+
output_full[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o1
136161

137162
return output_full

swift/megatron/utils/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ def forward_step_helper(model, inputs, dtype=None):
287287
args = get_args()
288288
if mpu.is_pipeline_first_stage():
289289
micro_batch_size = 1 # use qkv_format 'thd'
290+
if not args.padding_free:
291+
micro_batch_size = args.micro_batch_size
290292
seq_length = inputs['position_ids'].shape[-1]
291293
if args.sequence_parallel:
292294
seq_length //= mpu.get_tensor_model_parallel_world_size()

0 commit comments

Comments (0)