
Commit ce11425

[Refactor] refactor packing in RL train controller and train worker
1 parent 7c8f82c commit ce11425

File tree: 3 files changed, +533 −197 lines


xtuner/v1/rl/base/controller.py

Lines changed: 221 additions & 75 deletions
@@ -1,14 +1,16 @@
 import math
-from typing import Literal, TypedDict
+from typing import Literal, TypedDict, cast

+import numpy as np
 import ray
 import torch
 from ray.actor import ActorProxy

 from xtuner.v1.data_proto.sequence_context import SequenceContext
 from xtuner.v1.model.compose.base import BaseComposeConfig
+from xtuner.v1.rl.utils import get_seqlen_balanced_partitions
 from xtuner.v1.train.trainer import LoadCheckpointConfig
-from xtuner.v1.utils import ray_method
+from xtuner.v1.utils import get_logger, ray_method

 from .worker import TrainingWorker

@@ -23,6 +25,9 @@ class ColateItem(TypedDict):
 class RawTrainingController:
     def __init__(self, workers: list[TrainingWorker]) -> None:
         self.workers = workers
+        self.model_cfg = ray.get(self.workers[0].get_model_cfg.remote())
+        self.worker_cfg = ray.get(self.workers[0].get_worker_cfg.remote())
+        self.data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote())

     # TODO(hha): This logic is not general enough; it should reuse the sft function to support expand soft pack
     def _get_pack_infos(self, dataset, num_tokens, target, random=None):
@@ -164,95 +169,236 @@ def _grouped_by_max_length(self, packed_data_batches):
         # After sorting, this pack ends up first, so rank0's first step usually consumes fewer effective tokens than the other ranks; this is expected.
         return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)

-    @ray_method
-    def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int):
-        has_rollout_routed_experts = False
-        language_cfg = None
-        if data_batches[0]["seq_ctx"].rollout_routed_experts is not None:
-            model_cfg = ray.get(self.workers[0].get_model_cfg.remote())  # type: ignore[attr-defined]
-            has_rollout_routed_experts = True
-            language_cfg = model_cfg
-            if isinstance(model_cfg, BaseComposeConfig):
-                language_cfg = model_cfg.text_config
-
-        packed_data_batches = self._packing(data_batches, pack_max_length, language_cfg)
-        # packed_data_batches = self._grouped_by_max_length(packed_data_batches)
+    def _balance_split_batch(self, data_batches, partition_size):
+        """Reorder the data on single controller such that each dp rank gets
+        similar total tokens."""
+        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]
+        global_partition_lst = get_seqlen_balanced_partitions(
+            global_seqlen_lst, k_partitions=partition_size, equal_size=True
+        )
+        balanced_batches = []
+        tokens_in_partition = []
+        for partition in global_partition_lst:
+            partition_batch = [data_batches[i] for i in partition]
+            tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in partition_batch))
+            balanced_batches.append(partition_batch)
+        get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}")
+        return balanced_batches
+
+    def _create_padding_sample(
+        self,
+        pad_len: int,
+        pack_max_length: int,
+        is_qwen3_vl: bool = False,
+        has_rollout_routed_experts: bool = False,
+        has_rollout_logprobs: bool = True,
+        n_routed_experts: int = 0,
+        split_size: int = 1024,
+    ):
+        # padding input_ids
+        pad_tokens = tuple(
+            torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size)
+        )
+        if pad_len % split_size > 0:
+            pad_tokens = pad_tokens + (torch.zeros(1, pad_len % split_size, dtype=torch.long, device="cpu"),)
+        pad_tokens = cast(tuple[torch.LongTensor, ...], pad_tokens)
+        pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")
+        pad_seq_ctx.num_padding = pad_len
+
+        # padding mm position_ids
+        if is_qwen3_vl:
+            _position_ids_list = []
+            for pad_token in pad_tokens:
+                _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
+                _position_ids_list.append(_position_ids)
+            position_ids = torch.cat(_position_ids_list, dim=-1)
+            position_ids = cast(torch.LongTensor, position_ids)
+            pad_seq_ctx.position_ids = position_ids
+
+        # padding rollout routed experts
+        if has_rollout_routed_experts:
+            if pad_len == pack_max_length:
+                pad_rand_index = torch.randint(
+                    low=0, high=1, size=(1, 1, 1)
+                )  # add dummy data, true data will be initialized in train worker.fit
+            else:
+                pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
+            pad_seq_ctx.rollout_routed_experts = pad_rand_index

-        # TODO(hha): This logic is not general enough; it is tied to the model
-        is_qwen3_vl = False
-        if len(packed_data_batches[0]["seq_ctx"].position_ids.shape) == 3:
-            is_qwen3_vl = True
+        pad_labels = torch.full((1, pad_len), -100, dtype=torch.long, device="cpu")

-        # todo: support round up
-        num_packed_data_batches = len(packed_data_batches)
-        data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote())  # type: ignore[attr-defined]
-        dp_size = len(self.workers) // data_replicate_size
-        pad_num = math.ceil(num_packed_data_batches / dp_size) * dp_size - num_packed_data_batches
-        if pad_num > 0:
-            # Reduce the attn calculation time by using multiple short sequence packs
-            assert data_batches[0]["seq_ctx"].input_ids is not None
-            pad_tokens = tuple(
-                torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu")
-                for _ in range(pack_max_length // 1024)
-            )
-            if pack_max_length % 1024 > 0:
-                assert data_batches[0]["seq_ctx"].input_ids is not None
-                pad_tokens = pad_tokens + (
-                    torch.zeros(
-                        1, pack_max_length % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu"
-                    ),
-                )
-            pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu")  # type: ignore
-            pad_seq_ctx.num_padding = pack_max_length
-            if is_qwen3_vl:
-                _position_ids_list = []
-                for pad_token in pad_tokens:
-                    _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
-                    _position_ids_list.append(_position_ids)
-                pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1)  # type: ignore
-
-            pad_shifted_labels = torch.full(
+        if pad_len == pack_max_length:
+            pad_advantage_tensor = torch.full(
                 (1, pack_max_length),
                 -100,
-                dtype=packed_data_batches[0]["shifted_labels"].dtype,
+                dtype=torch.float32,
                 device="cpu",
             )
-            pad_advantages = torch.full(
-                (1, pack_max_length),
-                -100,
-                dtype=packed_data_batches[0]["advantages"].dtype,
-                device="cpu",
+        else:
+            pad_advantage_array = [-100] * math.ceil(pad_len / split_size)
+        pad_rollout_logprobs = (
+            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if has_rollout_logprobs else None
+        )
+
+        return {
+            "seq_ctx": pad_seq_ctx,
+            "shifted_labels": pad_labels,
+            "advantages": pad_advantage_tensor if pad_len == pack_max_length else pad_advantage_array,
+            "rollout_logprobs": pad_rollout_logprobs,
+        }
+
+    def _pack(self, mini_batch, pack_max_length):
+        seqlen_list = []
+        for data in mini_batch:
+            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (
+                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"
             )
+            seqlen_list.append(data["seq_ctx"].input_ids.numel())
+        total_length = sum(seqlen_list)

-            if has_rollout_routed_experts:
-                pad_rand_index = torch.randint(
-                    low=0,
-                    high=1,
-                    size=(1, 1, 1),  # add dummy data, true data will be initialized in train worker.fit
-                )
-                pad_seq_ctx.rollout_routed_experts = pad_rand_index
+        if total_length <= pack_max_length:
+            return [mini_batch]  # No packing needed

-            pad_rollout_logprobs = None
-            if "rollout_logprobs" in packed_data_batches[0] and packed_data_batches[0]["rollout_logprobs"] is not None:
-                pad_rollout_logprobs = torch.zeros(
-                    1, pack_max_length, dtype=packed_data_batches[0]["rollout_logprobs"].dtype, device="cpu"
-                )
-            pad_data = {
-                "seq_ctx": pad_seq_ctx,
-                "shifted_labels": pad_shifted_labels,
-                "advantages": pad_advantages,
-                "rollout_logprobs": pad_rollout_logprobs,
+        num_packs = math.ceil(total_length / pack_max_length)
+        partitions_indices = get_seqlen_balanced_partitions(
+            seqlen_list=seqlen_list, k_partitions=num_packs, equal_size=False
+        )
+
+        packed_mini_batches = []
+        for partition in partitions_indices:
+            packed_batch = [mini_batch[i] for i in partition]
+            packed_mini_batches.append(packed_batch)
+        return packed_mini_batches
+
+    def _get_data_batches_properties(self, data_batches: list[ColateItem]):
+        """Extract properties from the first element of data_batches."""
+        if not data_batches:
+            return {
+                "is_qwen3_vl": False,
+                "has_rollout_routed_experts": False,
+                "has_rollout_logprobs": False,
+                "n_routed_experts": None,
             }
-            pad_data_samples = [pad_data for _ in range(pad_num)]
-            packed_data_batches = packed_data_batches + pad_data_samples

-        print(f"len(packed_data_batches): {len(packed_data_batches)}")
+        first_item = data_batches[0]
+        seq_ctx = first_item["seq_ctx"]
+
+        is_qwen3_vl = seq_ctx.position_ids is not None and len(seq_ctx.position_ids.shape) == 3
+        has_rollout_logprobs = "rollout_logprobs" in first_item and first_item["rollout_logprobs"] is not None
+        has_rollout_routed_experts = seq_ctx.rollout_routed_experts is not None
+
+        model_cfg = ray.get(self.workers[0].get_model_cfg.remote())  # type: ignore[attr-defined]
+        language_cfg = None
+        if has_rollout_routed_experts:
+            language_cfg = model_cfg
+            if isinstance(model_cfg, BaseComposeConfig):
+                language_cfg = model_cfg.text_config
+
+        return {
+            "is_qwen3_vl": is_qwen3_vl,
+            "has_rollout_routed_experts": has_rollout_routed_experts,
+            "has_rollout_logprobs": has_rollout_logprobs,
+            "n_routed_experts": language_cfg.n_routed_experts if language_cfg is not None else None,
+        }
+
+    @ray_method
+    def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int, enable_balance: bool = True):
+        batch_props = self._get_data_batches_properties(data_batches)
+        is_qwen3_vl = batch_props["is_qwen3_vl"]
+        has_rollout_routed_experts = batch_props["has_rollout_routed_experts"]
+        has_rollout_logprobs = batch_props["has_rollout_logprobs"]
+        n_routed_experts = batch_props["n_routed_experts"]
+
+        world_size = len(self.workers)
+        dp_size = world_size // self.data_replicate_size
+        assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
+        optimizer_steps = self.worker_cfg.optimizer_steps
+
+        if enable_balance:
+            batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
+        else:
+            batches_per_dp_group = np.array_split(data_batches, dp_size)
+
+        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
+        max_packs_per_card = [0] * optimizer_steps
+
+        for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
+            # Within each worker, split the tokens evenly across optimizer_steps
+            mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
+
+            for step_idx, step_mini_batch in enumerate(mini_batch_for_steps):
+                # pack
+                pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
+                if len(pack_mini_batch) > max_packs_per_card[step_idx]:
+                    max_packs_per_card[step_idx] = len(pack_mini_batch)
+
+                for pack in pack_mini_batch:
+                    seq_ctx_list = [item["seq_ctx"] for item in pack]
+                    label_list = [item["shifted_labels"] for item in pack]
+                    advantage_list = [torch.tensor([item["advantage"]]).float().unsqueeze(0) for item in pack]
+                    rollout_logprobs_list = [
+                        item["rollout_logprobs"] if has_rollout_logprobs else None for item in pack
+                    ]
+                    padding_len = pack_max_length - sum([item["seq_ctx"].input_ids.numel() for item in pack])
+                    if padding_len > 0:
+                        padding_sample = self._create_padding_sample(
+                            padding_len,
+                            pack_max_length,
+                            is_qwen3_vl=is_qwen3_vl,
+                            has_rollout_routed_experts=has_rollout_routed_experts,
+                            has_rollout_logprobs=has_rollout_logprobs,
+                            n_routed_experts=n_routed_experts,
+                        )
+                        seq_ctx_list.append(padding_sample["seq_ctx"])
+                        label_list.append(padding_sample["shifted_labels"])
+                        advantage_list.extend(padding_sample["advantages"])
+                        rollout_logprobs_list.append(padding_sample["rollout_logprobs"])
+
+                    packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
+                    paced_shifted_labels = torch.cat(label_list, dim=1)
+                    packed_advantages = torch.tensor(advantage_list).float().unsqueeze(0)
+                    cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
+                    packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+                    packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
+                    if has_rollout_logprobs:
+                        cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
+                        packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
+                    else:
+                        packed_rollout_logprobs = None
+                    packed_data_batches[dp_rank][step_idx].append(
+                        {
+                            "seq_ctx": packed_seq_ctx,
+                            "shifted_labels": paced_shifted_labels,
+                            "advantages": packed_advantages,
+                            "rollout_logprobs": packed_rollout_logprobs,
+                        }
+                    )
+
+        get_logger().info(f"Gradient accumulation steps: {max_packs_per_card}")
+        # padding for each worker to have same number of packs
+        for dp_rank in range(dp_size):
+            for step_idx in range(optimizer_steps):
+                max_packs = max_packs_per_card[step_idx]
+                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+                num_padding_packs = max_packs - num_current_packs
+
+                if num_padding_packs > 0:
+                    padding_sample = self._create_padding_sample(
+                        pack_max_length,
+                        pack_max_length,
+                        is_qwen3_vl=is_qwen3_vl,
+                        has_rollout_routed_experts=has_rollout_routed_experts,
+                        has_rollout_logprobs=has_rollout_logprobs,
+                        n_routed_experts=n_routed_experts,
+                    )
+                    padding_samples = [padding_sample for _ in range(num_padding_packs)]
+                    packed_data_batches[dp_rank][step_idx].extend(padding_samples)

         handles = []
         for worker_idx, worker in enumerate(self.workers):
             handles.append(
                 worker.fit.remote(  # type: ignore[attr-defined]
-                    data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size],
+                    data_batches=packed_data_batches[worker_idx // self.data_replicate_size],
                     rollout_idx=rollout_idx,
                 )
             )
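Note on the balancing helper: both `_balance_split_batch` and `_pack` call `get_seqlen_balanced_partitions(seqlen_list, k_partitions, equal_size)` to group sample indices so that every partition carries a similar number of tokens. The snippet below is a hypothetical greedy stand-in, not the xtuner implementation, meant only to illustrate the behaviour the controller relies on (indices grouped by balanced token load, partitions may differ in sample count):

# Illustrative greedy approximation of seq-len balanced partitioning.
# Not the xtuner implementation; it only mimics the interface used above.
def greedy_balanced_partitions(seqlen_list: list[int], k_partitions: int) -> list[list[int]]:
    # Walk samples from longest to shortest and always assign the next one
    # to the partition with the smallest token load so far.
    order = sorted(range(len(seqlen_list)), key=lambda i: seqlen_list[i], reverse=True)
    partitions: list[list[int]] = [[] for _ in range(k_partitions)]
    loads = [0] * k_partitions
    for idx in order:
        target = loads.index(min(loads))
        partitions[target].append(idx)
        loads[target] += seqlen_list[idx]
    return partitions

# Example: six samples with skewed lengths split across two dp ranks.
print(greedy_balanced_partitions([4096, 512, 2048, 1024, 3072, 256], k_partitions=2))
# -> [[0, 3, 1], [4, 2, 5]], i.e. token loads of 5632 and 5376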

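On `_create_padding_sample`: the dummy sample is not built as one `pad_len`-token sequence; the tokens are chunked into `split_size` pieces, so the padding pack behaves like several short sequences (the removed code's comment notes this keeps attention cost low). A quick check of the chunk arithmetic, with an assumed `pad_len` and the `split_size` default of 1024 from the diff:

# Assumed pad_len for illustration; split_size mirrors the default above.
pad_len, split_size = 2500, 1024
chunk_sizes = [split_size] * (pad_len // split_size)
if pad_len % split_size > 0:
    chunk_sizes.append(pad_len % split_size)
print(chunk_sizes, sum(chunk_sizes))  # [1024, 1024, 452] 2500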
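On the new `fit` layout: `packed_data_batches` is indexed as `[dp_rank][step_idx][pack]`, and each `(dp_rank, step_idx)` cell is topped up to `max_packs_per_card[step_idx]` with full-length padding packs, so every dp rank runs the same number of gradient-accumulation micro-batches per optimizer step. A toy walkthrough of that bookkeeping, with invented pack counts:

# Invented numbers, assuming dp_size=2 and optimizer_steps=2.
dp_size, optimizer_steps = 2, 2
packs_per_cell = [
    [3, 2],  # dp_rank 0: packs produced for step 0 and step 1
    [2, 2],  # dp_rank 1
]
max_packs_per_card = [
    max(packs_per_cell[r][s] for r in range(dp_size)) for s in range(optimizer_steps)
]  # -> [3, 2]
padding_needed = [
    [max_packs_per_card[s] - packs_per_cell[r][s] for s in range(optimizer_steps)]
    for r in range(dp_size)
]  # -> [[0, 0], [1, 0]]: dp_rank 1 gets one all-padding pack for step 0
print(max_packs_per_card, padding_needed)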