Commit f4dae57

Commit message: fix
1 parent: 45fde49

2 files changed

Lines changed: 95 additions & 49 deletions

swift/megatron/trainers/grpo_trainer.py

Lines changed: 93 additions & 49 deletions
@@ -53,9 +53,9 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs):
         self.processing_class = self.template.processor
         self._prepare_metrics()
         self._init_grpo_params()
+        self._init_rollout_engine()
         self._prepare_rewards()
         self._prepare_scheduler()  # TODO
-        self._init_rollout_engine()  # Use mixin's rollout engine initialization
         # Initialize trainer state for reward functions to access training progress
         # Will be updated with actual values from Megatron args during training
         self.state = MegatronTrainerState()
@@ -297,30 +297,49 @@ def _get_encoded_batch(rollout_batch):
             encoded_list = [template.encode(data, return_length=True) for data in rollout_batch]
             encoded_batch = to_device(
                 template.data_collator(encoded_list, padding_to=get_padding_to(args)), self.device)
-            if 'cu_seq_lens_q' in encoded_batch:
-                cu_seq_lens_q = encoded_batch['cu_seq_lens_q']
-            else:
-                cu_seq_lens_q = get_packed_seq_params(encoded_batch['position_ids'])['cu_seq_lens_q']
-            seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
 
             labels = encoded_batch['labels']
             batch_size = len(rollout_batch)
-            max_seq_len = seq_lengths.max().item()
-            assert self.template.padding_free
 
             truncated_mask = torch.tensor([b['is_truncated'] for b in rollout_batch],
                                           dtype=torch.bool,
                                           device=self.device)
 
-            # completion_mask in rmpad format [1, total_tokens]
-            completion_mask_rmpad = (labels != -100).float()
-            completion_mask, _ = pad_logps_back_to_batch(
-                logps_rmpad=completion_mask_rmpad,
-                logits_to_keep=max_seq_len,
-                batch_size=batch_size,
-                seq_lengths=seq_lengths,
-                pad_value=0.0)
-            completion_mask = completion_mask.bool()
+            if self.template.padding_free:
+                # In padding_free mode, labels shape is [1, total_seq_len] (rmpad format)
+                # Calculate seq_lengths from cu_seq_lens or position_ids
+                if 'cu_seq_lens_q' in encoded_batch:
+                    cu_seq_lens_q = encoded_batch['cu_seq_lens_q']
+                else:
+                    cu_seq_lens_q = get_packed_seq_params(encoded_batch['position_ids'])['cu_seq_lens_q']
+                seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+                max_seq_len = seq_lengths.max().item()
+
+                # completion_mask in rmpad format [1, total_tokens]
+                completion_mask_rmpad = (labels != -100).float()
+                completion_mask, _ = pad_logps_back_to_batch(
+                    logps_rmpad=completion_mask_rmpad,
+                    logits_to_keep=max_seq_len,
+                    batch_size=batch_size,
+                    seq_lengths=seq_lengths,
+                    pad_value=0.0)
+                completion_mask = completion_mask.bool()
+            else:
+                # In non-padding_free mode, labels shape is [batch_size, seq_len] (batch format)
+                # Calculate seq_lengths from attention_mask
+                attention_mask = encoded_batch.get('attention_mask')
+                if attention_mask is not None:
+                    # attention_mask shape: [batch_size, seq_len] or [batch_size, 1, 1, seq_len]
+                    if attention_mask.dim() == 4:
+                        attention_mask = attention_mask[:, 0, 0, :]
+                    seq_lengths = attention_mask.sum(dim=-1).to(torch.int64)
+                else:
+                    # Fallback: assume full sequence length for each sample
+                    seq_lengths = torch.full((batch_size, ), labels.shape[-1], dtype=torch.int64, device=self.device)
+                max_seq_len = labels.shape[-1]
+
+                # completion_mask is already [batch_size, seq_len] in non-padding_free mode
+                completion_mask = (labels != -100)
 
             encoded_batch.update({
                 'completion_mask': completion_mask,  # [batch_size, max_seq_len]
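
Note: for readers unfamiliar with the rmpad layout this hunk branches on, here is a minimal, self-contained sketch of the rmpad-to-batch padding that pad_logps_back_to_batch performs; pad_back_to_batch_sketch and the toy values are illustrative stand-ins, not code from this repository.

import torch

def pad_back_to_batch_sketch(values_rmpad, seq_lengths, max_seq_len, pad_value=0.0):
    # Unpack a packed [1, total_tokens] tensor into [batch_size, max_seq_len] with padding.
    batch_size = seq_lengths.numel()
    out = torch.full((batch_size, max_seq_len), pad_value)
    start = 0
    for i, length in enumerate(seq_lengths.tolist()):
        out[i, :length] = values_rmpad[0, start:start + length]
        start += length
    return out

# cu_seq_lens_q marks sequence boundaries in the packed dimension:
# [0, 3, 7] -> two samples of lengths 3 and 4, so max_seq_len = 4
cu_seq_lens_q = torch.tensor([0, 3, 7])
seq_lengths = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
packed = torch.arange(7, dtype=torch.float32).unsqueeze(0)  # [1, 7] rmpad tensor
print(pad_back_to_batch_sketch(packed, seq_lengths, int(seq_lengths.max())))
# tensor([[0., 1., 2., 0.],
#         [3., 4., 5., 6.]])
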
@@ -400,10 +419,10 @@ def _get_encoded_batch(rollout_batch):
 
         if self.loss_type in ['cispo', 'dapo']:
             # Calculate num_items_in_batch
-            # Count tokens from all mini_batch_data (this includes gathered data from rollout_group)
-            total_token_count = sum(batch_data['seq_lengths'].sum().item() if self.template.
-                                    padding_free else batch_data['completion_mask'].sum().item()
-                                    for batch_data in mini_batch_data)
+            # Count completion tokens from all mini_batch_data (this includes gathered data from rollout_group)
+            # Use completion_mask.sum() for both padding_free and non-padding_free modes
+            # since we want the count of actual completion tokens, not sequence lengths
+            total_token_count = sum(batch_data['completion_mask'].sum().item() for batch_data in mini_batch_data)
 
             # All-reduce across all ranks
             total_token_count_tensor = torch.tensor(total_token_count, dtype=torch.int, device=self.device)
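
Note: a toy illustration (values assumed, not from the repository) of why the replaced expression and the new one can disagree. seq_lengths counts prompt plus completion tokens per sample, while completion_mask.sum() counts only positions whose labels are not -100, which is the quantity num_items_in_batch should reflect.

import torch

# -100 marks prompt and padding positions; other values are completion labels.
labels = torch.tensor([
    [-100, -100, 11, 12, 13, -100],     # 2 prompt tokens, 3 completion tokens, 1 pad
    [-100, 21, 22, -100, -100, -100],   # 1 prompt token, 2 completion tokens, 3 pads
])
completion_mask = labels != -100
seq_lengths = torch.tensor([5, 3])      # prompt + completion length per sample

print(completion_mask.sum().item())     # 5  -> completion tokens only
print(seq_lengths.sum().item())         # 8  -> includes prompt tokens as well
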
@@ -873,22 +892,34 @@ def _maybe_compute_logps(self, batch: Dict[str, Any]) -> Dict[str, Any]:
         with torch.no_grad(), self.null_ref_context() as ref_models:
             assert len(ref_models) == 1, 'GRPO currently does not support VPP.'
             ref_model = ref_models[0]
-            ref_per_token_logps_rmpad = self.model_forward(
+            ref_per_token_logps_raw = self.model_forward(
                 ref_model, iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
-            ref_per_token_logps, _ = pad_logps_back_to_batch(
-                logps_rmpad=ref_per_token_logps_rmpad,
-                logits_to_keep=max_seq_len,
-                batch_size=batch_size,
-                seq_lengths=seq_lengths)
+            if self.template.padding_free:
+                # In padding_free mode, logps are in rmpad format [1, total_tokens]
+                # Pad to batch format [batch_size, max_seq_len]
+                ref_per_token_logps, _ = pad_logps_back_to_batch(
+                    logps_rmpad=ref_per_token_logps_raw,
+                    logits_to_keep=max_seq_len,
+                    batch_size=batch_size,
+                    seq_lengths=seq_lengths)
+            else:
+                # In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
+                ref_per_token_logps = ref_per_token_logps_raw
         batch['ref_per_token_logps'] = ref_per_token_logps
 
-        old_per_token_logps_rmpad = self.model_forward(
+        old_per_token_logps_raw = self.model_forward(
             self.unwrapped_models[0], iter([deepcopy(inputs)]), no_grad=True, per_token=True)['logps']
-        old_per_token_logps, _ = pad_logps_back_to_batch(
-            logps_rmpad=old_per_token_logps_rmpad,
-            logits_to_keep=max_seq_len,
-            batch_size=batch_size,
-            seq_lengths=seq_lengths)
+        if self.template.padding_free:
+            # In padding_free mode, logps are in rmpad format [1, total_tokens]
+            # Pad to batch format [batch_size, max_seq_len]
+            old_per_token_logps, _ = pad_logps_back_to_batch(
+                logps_rmpad=old_per_token_logps_raw,
+                logits_to_keep=max_seq_len,
+                batch_size=batch_size,
+                seq_lengths=seq_lengths)
+        else:
+            # In non-padding_free mode, logps are already in batch format [batch_size, seq_len]
+            old_per_token_logps = old_per_token_logps_raw
         batch['old_per_token_logps'] = old_per_token_logps
 
         return batch
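
Note: the padding_free branch is now repeated for the ref and old logps. One way it could be factored is shown purely as a sketch below; the helper name _logps_to_batch is hypothetical and not part of this commit, and the sketch assumes the same module context (pad_logps_back_to_batch import, trainer self) as the code above.

def _logps_to_batch(self, logps_raw, batch_size, max_seq_len, seq_lengths):
    # Pad rmpad logps [1, total_tokens] to [batch_size, max_seq_len] only when padding_free is on.
    if self.template.padding_free:
        logps, _ = pad_logps_back_to_batch(
            logps_rmpad=logps_raw,
            logits_to_keep=max_seq_len,
            batch_size=batch_size,
            seq_lengths=seq_lengths)
        return logps
    # Already [batch_size, seq_len] in non-padding_free mode.
    return logps_raw
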
@@ -995,29 +1026,36 @@ def forward_step(self, data_iterator, model):
 
     @profiling_decorator
     def loss_func(self, output_tensor: torch.Tensor, data: Dict[str, Any]):
+        args = get_args()
         # Get pre-padded data in batch format [batch_size, max_seq_len]
         advantages = data['advantages']  # [batch_size]
         labels = data['labels']
         completion_mask = data['completion_mask']  # [batch_size, max_seq_len]
-        packed_seq_params = data['packed_seq_params']
+        packed_seq_params = data.get('packed_seq_params')
         truncated_mask = data['truncated_mask']  # [batch_size]
         seq_lengths = data['seq_lengths']  # [batch_size]
         max_seq_len = completion_mask.shape[1]
         micro_batch_size = self.micro_batch_size
 
-        # Use full sequence lengths directly (get_logps returns full sequences in CP mode)
-        lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size
-                                                 + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size]
-
-        # get_logps with per_token=True returns rmpad format [1, total_tokens]
-        # Pad to batch format [batch_size, max_seq_len]
-        per_token_logps_rmpad = self.get_logps(
-            output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True)
-        per_token_logps, _ = pad_logps_back_to_batch(
-            logps_rmpad=per_token_logps_rmpad,
-            logits_to_keep=max_seq_len,
-            batch_size=micro_batch_size,
-            seq_lengths=seq_lengths)
+        if args.padding_free:
+            # Use full sequence lengths directly (get_logps returns full sequences in CP mode)
+            lengths = packed_seq_params.cu_seqlens_q[1:micro_batch_size
+                                                     + 1] - packed_seq_params.cu_seqlens_q[:micro_batch_size]
+
+            # get_logps with per_token=True returns rmpad format [1, total_tokens]
+            # Pad to batch format [batch_size, max_seq_len]
+            per_token_logps_rmpad = self.get_logps(
+                output_tensor, labels, packed_seq_params, packed_seq_params.num_samples, per_token=True)
+            per_token_logps, _ = pad_logps_back_to_batch(
+                logps_rmpad=per_token_logps_rmpad,
+                logits_to_keep=max_seq_len,
+                batch_size=micro_batch_size,
+                seq_lengths=seq_lengths)
+        else:
+            # In non-padding_free mode, get_logps with per_token=True returns [batch_size, seq_len]
+            # No need to pad, already in batch format
+            lengths = seq_lengths
+            per_token_logps = self.get_logps(output_tensor, labels, packed_seq_params, micro_batch_size, per_token=True)
 
         # Get pre-padded ref/old/rollout logps from data
         ref_per_token_logps = data.get('ref_per_token_logps')  # [batch_size, max_seq_len] or None
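
Note: the lengths computation in the padding_free branch is plain cumulative-offset arithmetic over cu_seqlens_q. A toy check with assumed values (not taken from the repository):

import torch

# cu_seqlens_q holds cumulative sequence start offsets for the packed micro-batch.
cu_seqlens_q = torch.tensor([0, 4, 9, 15])
micro_batch_size = 3
lengths = cu_seqlens_q[1:micro_batch_size + 1] - cu_seqlens_q[:micro_batch_size]
print(lengths)  # tensor([4, 5, 6]) -> per-sample sequence lengths
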
@@ -1245,13 +1283,19 @@ def model_forward(self, model, data_iterator, no_grad=True, per_token=False):
         with self.stimer(bdata=True):
             data = self.get_batch(data_iterator)
         data.pop('loss_scale', None)
+        input_ids = data.get('input_ids')
         labels = data.get('labels')
         context = torch.no_grad() if no_grad else nullcontext()
         with context:
             output_tensor = forward_step_helper(model, data)
-            packed_seq_params = data['packed_seq_params']
+            # packed_seq_params only exists in padding_free mode
+            packed_seq_params = data.get('packed_seq_params')
+            if packed_seq_params is not None:
+                num_samples = packed_seq_params.num_samples
+            else:
+                num_samples = input_ids.shape[0] if input_ids is not None else labels.shape[0]
             data['logps'] = None if labels is None else self.get_logps(
-                output_tensor, labels, data['packed_seq_params'], packed_seq_params.num_samples, per_token=per_token)
+                output_tensor, labels, packed_seq_params, num_samples, per_token=per_token)
         return data
 
     def inputs2requests(self, inputs: Union[DataType, List[RolloutInferRequest]]) -> List[RolloutInferRequest]:

swift/megatron/utils/utils.py

Lines changed: 2 additions & 0 deletions
@@ -287,6 +287,8 @@ def forward_step_helper(model, inputs, dtype=None):
     args = get_args()
     if mpu.is_pipeline_first_stage():
         micro_batch_size = 1  # use qkv_format 'thd'
+        if not args.padding_free:
+            micro_batch_size = args.micro_batch_size
         seq_length = inputs['position_ids'].shape[-1]
         if args.sequence_parallel:
             seq_length //= mpu.get_tensor_model_parallel_world_size()
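
Note: a short sketch of the shape difference behind this guard (shapes assumed for illustration). With padding_free, the samples of a micro-batch are packed into a single 'thd' sequence, so the effective batch dimension is 1; without it, inputs keep a real batch dimension and the configured micro_batch_size applies.

import torch

# padding_free / 'thd' packing: one row holding every token of the micro-batch
packed_input_ids = torch.randint(0, 1000, (1, 4 + 5 + 6))  # [1, total_tokens]

# non-padding_free: padded batch with a real batch dimension
padded_input_ids = torch.randint(0, 1000, (3, 6))          # [micro_batch_size, seq_len]

print(packed_input_ids.shape)  # torch.Size([1, 15])
print(padded_input_ids.shape)  # torch.Size([3, 6])
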
