Skip to content

Commit fee1061

Browse files
committed
update
2 parents f4dae57 + f3ce0a6 commit fee1061

2 files changed

Lines changed: 66 additions & 30 deletions

File tree

swift/megatron/trainers/grpo_trainer.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,9 @@ def __init__(self, args: MegatronRLHFArguments, template: Template, **kwargs):
5454
self._prepare_metrics()
5555
self._init_grpo_params()
5656
self._init_rollout_engine()
57+
self._init_rollout_engine()
5758
self._prepare_rewards()
58-
self._prepare_scheduler() # TODO
59+
self._prepare_scheduler()
5960
# Initialize trainer state for reward functions to access training progress
6061
# Will be updated with actual values from Megatron args during training
6162
self.state = MegatronTrainerState()
@@ -1017,6 +1018,16 @@ def forward_step(self, data_iterator, model):
10171018
# train_batch_size
10181019
# return: output_tensor, loss_func
10191020
data = self.get_batch(data_iterator)
1021+
data = next(data_iterator)
1022+
advantages = data.pop('advantages')
1023+
truncated_mask = data.pop('truncated_mask')
1024+
seq_lengths = data.pop('seq_lengths')
1025+
data = self._prepare_batch(data)
1026+
data.update({
1027+
'advantages': advantages,
1028+
'truncated_mask': truncated_mask,
1029+
'seq_lengths': seq_lengths,
1030+
})
10201031
data.pop('loss_scale', None)
10211032
inputs = self._prepare_model_inputs(data)
10221033

swift/megatron/trainers/rlhf_mixin.py

Lines changed: 54 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -86,12 +86,12 @@ def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
8686
Works for both logps (float) and masks (bool/int).
8787
8888
Args:
89-
tensor: [1, packed_len/cp_size] - CP-split tensor (any dtype)
90-
packed_seq_params: PackedSeqParams object
89+
tensor: [1, packed_len/cp_size] in padding_free mode, or [batch_size, seq_len/cp_size] otherwise
90+
packed_seq_params: PackedSeqParams object (None in non-padding_free mode)
9191
num_samples: Number of samples in the batch
9292
9393
Returns:
94-
output_full: [1, packed_len] - Full sequence tensor
94+
output_full: [1, packed_len] in padding_free mode, or [batch_size, seq_len] otherwise
9595
"""
9696
args = get_args()
9797
cp_size = args.context_parallel_size
@@ -102,36 +102,61 @@ def _postprocess_packed_tensor_cp(self, tensor, packed_seq_params, num_samples):
102102
torch.distributed.all_gather(output_list, tensor.contiguous(), group=mpu.get_context_parallel_group())
103103
output_list[cp_rank] = tensor
104104

105-
# Reconstruct full sequence
106-
# Shape: [1, packed_len/cp_size] -> [1, packed_len]
107-
cu_seqlens_full = packed_seq_params.cu_seqlens_q
108-
cu_seqlens_cp = cu_seqlens_full // cp_size
105+
if packed_seq_params is not None:
106+
# padding_free mode: [1, packed_len/cp_size] -> [1, packed_len]
107+
cu_seqlens_full = packed_seq_params.cu_seqlens_q
108+
cu_seqlens_cp = cu_seqlens_full // cp_size
109109

110-
# Calculate total packed length
111-
total_packed_len = cu_seqlens_full[num_samples].item()
112-
output_full = tensor.new_zeros(1, total_packed_len)
110+
# Calculate total packed length
111+
total_packed_len = cu_seqlens_full[num_samples].item()
112+
output_full = tensor.new_zeros(1, total_packed_len)
113113

114-
# Reconstruct each sequence
115-
for i in range(num_samples):
116-
start_full = cu_seqlens_full[i].item()
117-
end_full = cu_seqlens_full[i + 1].item()
118-
seq_len = end_full - start_full
114+
# Reconstruct each sequence
115+
for i in range(num_samples):
116+
start_full = cu_seqlens_full[i].item()
117+
end_full = cu_seqlens_full[i + 1].item()
118+
seq_len = end_full - start_full
119+
120+
# Length of each chunk after CP split
121+
chunk_len = seq_len // cp_size
122+
half_chunk = chunk_len // 2
123+
124+
# Concatenate from each CP rank's output (load-balanced split)
125+
for j in range(cp_size):
126+
o = output_list[j][0]
127+
start_cp = cu_seqlens_cp[i].item()
128+
129+
# Get two half chunks (CP's load-balanced split)
130+
o0 = o[start_cp:start_cp + half_chunk]
131+
o1 = o[start_cp + half_chunk:start_cp + chunk_len]
132+
133+
# Place back to full sequence
134+
output_full[0, start_full + j * half_chunk:start_full + (j + 1) * half_chunk] = o0
135+
output_full[0, end_full - (j + 1) * half_chunk:end_full - j * half_chunk] = o1
136+
else:
137+
# non-padding_free mode: [batch_size, seq_len/cp_size] -> [batch_size, seq_len]
138+
# Each CP rank has chunks split with load-balanced pattern (2*cp_size chunks)
139+
batch_size = tensor.shape[0]
140+
seq_len_per_cp = tensor.shape[1]
141+
full_seq_len = seq_len_per_cp * cp_size
119142

120-
# Length of each chunk after CP split
121-
chunk_len = seq_len // cp_size
122-
half_chunk = chunk_len // 2
143+
output_full = tensor.new_zeros(batch_size, full_seq_len)
123144

124-
# Concatenate from each CP rank's output (load-balanced split)
125-
for j in range(cp_size):
126-
o = output_list[j][0]
127-
start_cp = cu_seqlens_cp[i].item()
145+
# Each CP rank j holds chunks j and (2*cp_size - j - 1) from the original 2*cp_size split
146+
# Reconstruct the full sequence by placing chunks back in correct positions
147+
chunk_len = full_seq_len // (2 * cp_size)
128148

129-
# Get two half chunks (CP's load-balanced split)
130-
o0 = o[start_cp:start_cp + half_chunk]
131-
o1 = o[start_cp + half_chunk:start_cp + chunk_len]
132-
133-
# Place back to full sequence
134-
output_full[0, start_full + j * half_chunk:start_full + (j + 1) * half_chunk] = o0
135-
output_full[0, end_full - (j + 1) * half_chunk:end_full - j * half_chunk] = o1
149+
for j in range(cp_size):
150+
o = output_list[j] # [batch_size, seq_len_per_cp]
151+
# This rank holds 2 chunks: chunk j and chunk (2*cp_size - j - 1)
152+
half_len = seq_len_per_cp // 2
153+
o0 = o[:, :half_len] # First half -> chunk j
154+
o1 = o[:, half_len:] # Second half -> chunk (2*cp_size - j - 1)
155+
156+
# Place chunk j at position j * chunk_len
157+
output_full[:, j * chunk_len:(j + 1) * chunk_len] = o0
158+
# Place chunk (2*cp_size - j - 1) at position (2*cp_size - j - 1) * chunk_len
159+
reverse_idx = 2 * cp_size - j - 1
160+
output_full[:, reverse_idx * chunk_len:(reverse_idx + 1) * chunk_len] = o1
136161

137162
return output_full

0 commit comments

Comments (0)