Skip to content

Commit 265707b

Browse files
committed
add clearer type annotations in TrainController
1 parent 02b2910 commit 265707b

File tree

3 files changed

+155
-132
lines changed

3 files changed

+155
-132
lines changed

xtuner/v1/rl/base/controller.py

Lines changed: 133 additions & 120 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
import math
22
import os
33
import random
4-
from typing import Literal, TypedDict, cast
4+
from pathlib import Path
5+
from typing import Literal, cast
56

67
import numpy as np
78
import ray
@@ -14,17 +15,10 @@
1415
from xtuner.v1.train.trainer import LoadCheckpointConfig
1516
from xtuner.v1.utils import get_logger, ray_method
1617

17-
from .worker import TrainingWorker, WorkerLogItem
18-
1918

2019
TRAIN_RAY_GET_TIMEOUT = os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600) # default 5 hours
2120

22-
23-
class ColateItem(TypedDict):
24-
seq_ctx: SequenceContext
25-
shifted_labels: torch.Tensor
26-
advantage: float
27-
rollout_logprobs: torch.Tensor | None
21+
from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem
2822

2923

3024
class RawTrainingController:
@@ -36,6 +30,17 @@ def __init__(self, workers: list[TrainingWorker]) -> None:
3630
self.workers[0].get_data_replicate_size.remote(),
3731
]
3832
self.model_cfg, self.worker_cfg, self.data_replicate_size = ray.get(refs)
33+
log_dir = self.worker_cfg.log_dir
34+
self.log_dir = None
35+
if log_dir is not None:
36+
self.log_dir = Path(log_dir) if isinstance(log_dir, str) else log_dir
37+
self.logger = get_logger(log_dir=self.log_dir, tag="TrainingController")
38+
else:
39+
self.logger = get_logger()
40+
self.is_qwen3_vl = False
41+
self.has_rollout_routed_experts = False
42+
self.has_rollout_logprobs = False
43+
self.n_routed_experts = None
3944

4045
# TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack
4146
def _get_pack_infos(self, dataset, num_tokens, target, random=None):
@@ -100,7 +105,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
100105
pad_len = pack_max_length - total_len
101106
seq_ctx_list = [data_batches[i]["seq_ctx"] for i in indices]
102107
label_list = [data_batches[i]["shifted_labels"] for i in indices]
103-
advantage_list = [data_batches[i]["advantage"] for i in indices]
108+
advantage_list = [data_batches[i]["advantages"] for i in indices]
104109

105110
rollout_logprobs_list = None
106111
if "rollout_logprobs" in data_batches[0] and data_batches[0]["rollout_logprobs"] is not None:
@@ -177,10 +182,10 @@ def _grouped_by_max_length(self, packed_data_batches):
177182
# 排序后这条 pack 会被放在最前面,导致 rank0 的第一个 step 消耗的有效 token 数往往少于其他 rank,是正常现象。
178183
return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)
179184

180-
def _balance_split_batch(self, data_batches, partition_size):
185+
def _balance_split_batch(self, data_batches: list[WorkerInputItem], partition_size) -> list[list[WorkerInputItem]]:
181186
"""Reorder the data on single controller such that each dp rank gets
182187
similar total tokens."""
183-
global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]
188+
global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches] # type: ignore[union-attr]
184189
global_partition_lst = get_seqlen_balanced_partitions(
185190
global_seqlen_lst, k_partitions=partition_size, equal_size=True
186191
)
@@ -193,16 +198,12 @@ def _balance_split_batch(self, data_batches, partition_size):
193198
get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}")
194199
return balanced_batches
195200

196-
def _create_padding_sample(
201+
def _create_padding_item(
197202
self,
198203
pad_len: int,
199204
pack_max_length: int,
200-
is_qwen3_vl: bool = False,
201-
has_rollout_routed_experts: bool = False,
202-
has_rollout_logprobs: bool = True,
203-
n_routed_experts: int | None = None,
204205
split_size: int = 1024,
205-
):
206+
) -> WorkerInputItem:
206207
# padding input_ids
207208
pad_tokens = tuple(
208209
torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size)
@@ -214,7 +215,7 @@ def _create_padding_sample(
214215
pad_seq_ctx.num_padding = pad_len
215216

216217
# padding mm positions_ids
217-
if is_qwen3_vl:
218+
if self.is_qwen3_vl:
218219
_position_ids_list = []
219220
for pad_token in pad_tokens:
220221
_position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
@@ -224,17 +225,17 @@ def _create_padding_sample(
224225
pad_seq_ctx.position_ids = position_ids
225226

226227
# padding rollout routed experts
227-
if has_rollout_routed_experts:
228-
assert n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
228+
if self.has_rollout_routed_experts:
229+
assert self.n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
229230
if pad_len == pack_max_length:
230231
pad_rand_index = torch.randint(
231232
low=0, high=1, size=(1, 1, 1)
232233
) # add dummy data, true data will be initialized in train worker.fit
233234
else:
234-
pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
235+
pad_rand_index = torch.randint(low=0, high=self.n_routed_experts, size=(pad_len, 1, 1))
235236
pad_seq_ctx.rollout_routed_experts = pad_rand_index
236237

237-
pad_labels = torch.full((1, pad_len), -100, dtype=torch.long, device="cpu")
238+
pad_labels = cast(torch.LongTensor, torch.full((1, pad_len), -100, dtype=torch.int64, device="cpu"))
238239
pad_advantage_length = pack_max_length if pad_len == pack_max_length else math.ceil(pad_len / 1024)
239240
pad_advantage = torch.full(
240241
(1, pad_advantage_length),
@@ -243,24 +244,27 @@ def _create_padding_sample(
243244
device="cpu",
244245
)
245246
pad_rollout_logprobs = (
246-
torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if has_rollout_logprobs else None
247+
torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if self.has_rollout_logprobs else None
247248
)
248249

249-
return {
250+
padding_item: WorkerInputItem = {
250251
"seq_ctx": pad_seq_ctx,
251252
"shifted_labels": pad_labels,
252253
"advantages": pad_advantage,
253254
"rollout_logprobs": pad_rollout_logprobs,
254255
}
256+
return padding_item
255257

256-
def _pack(self, mini_batch, pack_max_length):
258+
def _rearrange_batch_for_pack(
259+
self, mini_batch: list[WorkerInputItem], pack_max_length: int
260+
) -> list[list[WorkerInputItem]]:
257261
assert len(mini_batch) > 0, "mini_batch should not be empty"
258262
seqlen_list = []
259263
for data in mini_batch:
260-
assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (
261-
f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"
264+
assert data["seq_ctx"].input_ids.numel() <= pack_max_length, ( # type: ignore[union-attr]
265+
f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}" # type: ignore[union-attr]
262266
)
263-
seqlen_list.append(data["seq_ctx"].input_ids.numel())
267+
seqlen_list.append(data["seq_ctx"].input_ids.numel()) # type: ignore[union-attr]
264268
total_length = sum(seqlen_list)
265269

266270
if total_length <= pack_max_length:
@@ -277,15 +281,10 @@ def _pack(self, mini_batch, pack_max_length):
277281
packed_mini_batches.append(packed_batch)
278282
return packed_mini_batches
279283

280-
def _get_data_batches_properties(self, data_batches: list[ColateItem]):
284+
def _set_data_batches_properties(self, data_batches: list[WorkerInputItem]):
281285
"""Extract properties from the first element of data_batches."""
282286
if not data_batches:
283-
return {
284-
"is_qwen3_vl": False,
285-
"has_rollout_routed_experts": False,
286-
"has_rollout_logprobs": False,
287-
"n_routed_experts": None,
288-
}
287+
return
289288

290289
first_item = data_batches[0]
291290
seq_ctx = first_item["seq_ctx"]
@@ -300,114 +299,128 @@ def _get_data_batches_properties(self, data_batches: list[ColateItem]):
300299
if isinstance(self.model_cfg, BaseComposeConfig):
301300
language_cfg = self.model_cfg.text_config
302301

303-
return {
304-
"is_qwen3_vl": is_qwen3_vl,
305-
"has_rollout_routed_experts": has_rollout_routed_experts,
306-
"has_rollout_logprobs": has_rollout_logprobs,
307-
"n_routed_experts": language_cfg.n_routed_experts if language_cfg is not None else None,
302+
self.is_qwen3_vl = is_qwen3_vl
303+
self.has_rollout_routed_experts = has_rollout_routed_experts
304+
self.has_rollout_logprobs = has_rollout_logprobs
305+
self.n_routed_experts = language_cfg.n_routed_experts if language_cfg is not None else None
306+
307+
def _pad_and_pack_batches(self, batch4pack: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem:
308+
seq_ctx_list = [item["seq_ctx"] for item in batch4pack]
309+
label_list = [item["shifted_labels"] for item in batch4pack]
310+
advantage_list = [torch.tensor([item["advantages"]]).float().unsqueeze(0) for item in batch4pack]
311+
rollout_logprobs_list = [
312+
item["rollout_logprobs"] if self.has_rollout_logprobs else None for item in batch4pack
313+
]
314+
cur_length = 0
315+
for item in batch4pack:
316+
cur_length += item["seq_ctx"].input_ids.numel() # type: ignore[union-attr]
317+
padding_len = pack_max_length - cur_length
318+
319+
if padding_len > 0:
320+
padding_item = self._create_padding_item(padding_len, pack_max_length)
321+
seq_ctx_list.append(padding_item["seq_ctx"])
322+
label_list.append(padding_item["shifted_labels"])
323+
advantage_list.append(padding_item["advantages"])
324+
rollout_logprobs_list.append(padding_item["rollout_logprobs"])
325+
326+
packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
327+
packed_shifted_labels = torch.cat(label_list, dim=1) # type: ignore[arg-type]
328+
packed_shifted_labels = cast(torch.LongTensor, packed_shifted_labels)
329+
cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
330+
packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
331+
packed_advantages = torch.cat(advantage_list, dim=1)
332+
packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
333+
if self.has_rollout_logprobs:
334+
cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
335+
packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
336+
else:
337+
packed_rollout_logprobs = None
338+
339+
optimizer_step_packs: WorkerInputItem = {
340+
"seq_ctx": packed_seq_ctx,
341+
"shifted_labels": packed_shifted_labels,
342+
"advantages": packed_advantages,
343+
"rollout_logprobs": packed_rollout_logprobs,
308344
}
345+
return optimizer_step_packs
346+
347+
def _pad_to_max_packs_across_workes(
348+
self,
349+
packed_data_batches: list[list[list[WorkerInputItem]]],
350+
step_idx: int,
351+
max_packs: int,
352+
pack_max_length: int,
353+
):
354+
for dp_rank in range(len(packed_data_batches)):
355+
num_current_packs = len(packed_data_batches[dp_rank][step_idx])
356+
num_padding_packs = max_packs - num_current_packs
357+
358+
if num_padding_packs > 0:
359+
padding_item = self._create_padding_item(pack_max_length, pack_max_length)
360+
padding_items = [padding_item for _ in range(num_padding_packs)]
361+
packed_data_batches[dp_rank][step_idx].extend(padding_items)
309362

310363
@ray_method
311364
def fit(
312-
self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int, enable_dp_balance: bool = True
313-
):
314-
batch_props = self._get_data_batches_properties(data_batches)
315-
is_qwen3_vl = batch_props["is_qwen3_vl"]
316-
has_rollout_routed_experts = batch_props["has_rollout_routed_experts"]
317-
has_rollout_logprobs = batch_props["has_rollout_logprobs"]
318-
n_routed_experts = batch_props["n_routed_experts"]
365+
self,
366+
data_batches: list[WorkerInputItem],
367+
pack_max_length: int,
368+
rollout_idx: int,
369+
enable_dp_balance: bool = True,
370+
) -> list[WorkerLogItem]:
371+
self._set_data_batches_properties(data_batches)
319372

320373
world_size = len(self.workers)
321374
dp_size = world_size // self.data_replicate_size
322375
assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
323376
optimizer_steps = self.worker_cfg.optimizer_steps
324377

378+
batches_per_dp_group: list[list[WorkerInputItem]]
325379
if enable_dp_balance:
326380
# 按照 dp_size 对数据进行重新分配,保证每个 dp rank 上的 token 数量大致相同
327381
batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
328382
else:
329383
batches_per_dp_group = np.array_split(data_batches, dp_size)
330384
tokens_in_partition = []
331385
for batch in batches_per_dp_group:
332-
tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in batch))
333-
get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
334-
335-
packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
336-
max_packs_per_card = [0] * optimizer_steps
386+
dp_group_total_tokens = 0
387+
for data in batch:
388+
dp_group_total_tokens += data["seq_ctx"].input_ids.numel() # type: ignore[union-attr]
389+
tokens_in_partition.append(dp_group_total_tokens)
390+
self.logger.info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
391+
392+
packed_data_batches: list[list[list[WorkerInputItem]]] = [
393+
[[] for _ in range(optimizer_steps)] for _ in range(dp_size)
394+
]
395+
max_packs_per_step = [0] * optimizer_steps
337396

338397
for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
339-
# 每个worker 内部按照optimizer_steps将token均分
398+
# 每个worker内部按照optimizer_steps将token均分
340399
if enable_dp_balance:
341400
random.shuffle(dp_worker_data_batches)
342-
mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
401+
mini_batch_for_steps: list[list[WorkerInputItem]] = self._balance_split_batch(
402+
dp_worker_data_batches, optimizer_steps
403+
)
343404

344405
for step_idx, step_mini_batch in enumerate(mini_batch_for_steps):
345-
# pack
346-
pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
347-
if len(pack_mini_batch) > max_packs_per_card[step_idx]:
348-
max_packs_per_card[step_idx] = len(pack_mini_batch)
349-
350-
for pack in pack_mini_batch:
351-
seq_ctx_list = [item["seq_ctx"] for item in pack]
352-
label_list = [item["shifted_labels"] for item in pack]
353-
advantage_list = [torch.tensor([item["advantage"]]).float().unsqueeze(0) for item in pack]
354-
rollout_logprobs_list = [
355-
item["rollout_logprobs"] if has_rollout_logprobs else None for item in pack
356-
]
357-
padding_len = pack_max_length - sum([item["seq_ctx"].input_ids.numel() for item in pack])
358-
if padding_len > 0:
359-
padding_sample = self._create_padding_sample(
360-
padding_len,
361-
pack_max_length,
362-
is_qwen3_vl=is_qwen3_vl,
363-
has_rollout_routed_experts=has_rollout_routed_experts,
364-
has_rollout_logprobs=has_rollout_logprobs,
365-
n_routed_experts=n_routed_experts,
366-
)
367-
seq_ctx_list.append(padding_sample["seq_ctx"])
368-
label_list.append(padding_sample["shifted_labels"])
369-
advantage_list.append(padding_sample["advantages"])
370-
rollout_logprobs_list.append(padding_sample["rollout_logprobs"])
371-
372-
packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
373-
packed_shifted_labels = torch.cat(label_list, dim=1)
374-
cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
375-
packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
376-
packed_advantages = torch.cat(advantage_list, dim=1)
377-
packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
378-
if has_rollout_logprobs:
379-
cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
380-
packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
381-
else:
382-
packed_rollout_logprobs = None
383-
packed_data_batches[dp_rank][step_idx].append(
384-
{
385-
"seq_ctx": packed_seq_ctx,
386-
"shifted_labels": packed_shifted_labels,
387-
"advantages": packed_advantages,
388-
"rollout_logprobs": packed_rollout_logprobs,
389-
}
390-
)
406+
# rearrange mini batch to fit into packs of pack_max_length
407+
batch4pack_list: list[list[WorkerInputItem]] = self._rearrange_batch_for_pack(
408+
step_mini_batch, pack_max_length
409+
)
410+
if len(batch4pack_list) > max_packs_per_step[step_idx]:
411+
max_packs_per_step[step_idx] = len(batch4pack_list)
391412

392-
get_logger().info(f"Gradient accumulation steps: {max_packs_per_card}")
393-
# padding for each worker to have same number of packs
394-
for dp_rank in range(dp_size):
395-
for step_idx in range(optimizer_steps):
396-
max_packs = max_packs_per_card[step_idx]
397-
num_current_packs = len(packed_data_batches[dp_rank][step_idx])
398-
num_padding_packs = max_packs - num_current_packs
399-
400-
if num_padding_packs > 0:
401-
padding_sample = self._create_padding_sample(
402-
pack_max_length,
403-
pack_max_length,
404-
is_qwen3_vl=is_qwen3_vl,
405-
has_rollout_routed_experts=has_rollout_routed_experts,
406-
has_rollout_logprobs=has_rollout_logprobs,
407-
n_routed_experts=n_routed_experts,
408-
)
409-
padding_samples = [padding_sample for _ in range(num_padding_packs)]
410-
packed_data_batches[dp_rank][step_idx].extend(padding_samples)
413+
for batch4pack in batch4pack_list:
414+
# pad and pack batches into a single optimizer step pack
415+
step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length)
416+
packed_data_batches[dp_rank][step_idx].append(step_pack)
417+
418+
self.logger.info(f"Gradient accumulation for each optimizer steps: {max_packs_per_step}")
419+
420+
# padding for each worker to have same number of packs in each optimizer step
421+
for step_idx in range(optimizer_steps):
422+
max_packs = max_packs_per_step[step_idx]
423+
self._pad_to_max_packs_across_workes(packed_data_batches, step_idx, max_packs, pack_max_length)
411424

412425
handles = []
413426
for worker_idx, worker in enumerate(self.workers):

0 commit comments

Comments
 (0)