
Commit 2dfbb8b

add clearer type annotations in TrainController
1 parent 3be2fb4 commit 2dfbb8b

3 files changed (+136, -123 lines)


xtuner/v1/rl/base/controller.py

Lines changed: 132 additions & 119 deletions
@@ -1,6 +1,7 @@
 import math
 import random
-from typing import Literal, TypedDict, cast
+from pathlib import Path
+from typing import Literal, cast
 
 import numpy as np
 import ray
@@ -13,14 +14,7 @@
 from xtuner.v1.train.trainer import LoadCheckpointConfig
 from xtuner.v1.utils import get_logger, ray_method
 
-from .worker import TrainingWorker
-
-
-class ColateItem(TypedDict):
-    seq_ctx: SequenceContext
-    shifted_labels: torch.Tensor
-    advantage: float
-    rollout_logprobs: torch.Tensor | None
+from .worker import TrainingWorker, WorkerInputItem
 
 
 class RawTrainingController:
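
Note: WorkerInputItem is imported from .worker, whose definition is not part of this diff. Judging from the keys this file reads ("seq_ctx", "shifted_labels", "advantages", "rollout_logprobs"), it presumably mirrors the removed ColateItem with the pluralized advantages key. A hypothetical sketch, not the actual definition in worker.py:

from typing import TypedDict

import torch


class WorkerInputItem(TypedDict):
    # Hypothetical reconstruction; the real TypedDict lives in xtuner/v1/rl/base/worker.py.
    seq_ctx: "SequenceContext"  # packed sequence data (input_ids, cu_seq_lens_q, ...); forward ref kept as a string
    shifted_labels: torch.Tensor  # (1, seq_len) labels, with -100 on padded positions
    advantages: torch.Tensor | float  # per-sample advantage, expanded to per-token values later
    rollout_logprobs: torch.Tensor | None  # logprobs from the rollout engine, if available
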
@@ -32,6 +26,17 @@ def __init__(self, workers: list[TrainingWorker]) -> None:
             self.workers[0].get_data_replicate_size.remote(),
         ]
         self.model_cfg, self.worker_cfg, self.data_replicate_size = ray.get(refs)
+        log_dir = self.worker_cfg.log_dir
+        self.log_dir = None
+        if log_dir is not None:
+            self.log_dir = Path(log_dir) if isinstance(log_dir, str) else log_dir
+            self.logger = get_logger(log_dir=self.log_dir, tag="TrainingController")
+        else:
+            self.logger = get_logger()
+        self.is_qwen3_vl = False
+        self.has_rollout_routed_experts = False
+        self.has_rollout_logprobs = False
+        self.n_routed_experts = None
 
     # TODO(hha): This logic is not generic enough; it should reuse the sft function so that expand soft pack is supported
     def _get_pack_infos(self, dataset, num_tokens, target, random=None):
@@ -96,7 +101,7 @@ def _packing(self, data_batches, pack_max_length, language_cfg):
             pad_len = pack_max_length - total_len
             seq_ctx_list = [data_batches[i]["seq_ctx"] for i in indices]
             label_list = [data_batches[i]["shifted_labels"] for i in indices]
-            advantage_list = [data_batches[i]["advantage"] for i in indices]
+            advantage_list = [data_batches[i]["advantages"] for i in indices]
 
             rollout_logprobs_list = None
             if "rollout_logprobs" in data_batches[0] and data_batches[0]["rollout_logprobs"] is not None:
@@ -173,10 +178,10 @@ def _grouped_by_max_length(self, packed_data_batches):
         # After sorting, this pack is placed first, so rank0's first step often consumes fewer effective tokens than the other ranks; this is expected.
         return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True)
 
-    def _balance_split_batch(self, data_batches, partition_size):
+    def _balance_split_batch(self, data_batches: list[WorkerInputItem], partition_size) -> list[list[WorkerInputItem]]:
         """Reorder the data on single controller such that each dp rank gets
         similar total tokens."""
-        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]
+        global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches]  # type: ignore[union-attr]
         global_partition_lst = get_seqlen_balanced_partitions(
             global_seqlen_lst, k_partitions=partition_size, equal_size=True
         )
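
For reference, get_seqlen_balanced_partitions returns k_partitions lists of indices whose sequence-length sums are roughly equal. A minimal greedy stand-in that illustrates the contract (not xtuner's actual algorithm, and it ignores the equal_size constraint):

def greedy_balanced_partitions(seqlens: list[int], k_partitions: int) -> list[list[int]]:
    # Assign each sequence, longest first, to the currently lightest partition.
    partitions: list[list[int]] = [[] for _ in range(k_partitions)]
    loads = [0] * k_partitions
    for idx in sorted(range(len(seqlens)), key=lambda i: seqlens[i], reverse=True):
        target = loads.index(min(loads))
        partitions[target].append(idx)
        loads[target] += seqlens[idx]
    return partitions


seqlens = [4096, 1024, 2048, 512, 3072, 1536]
parts = greedy_balanced_partitions(seqlens, k_partitions=2)
print(parts)                                        # [[0, 5, 3], [4, 2, 1]]
print([sum(seqlens[i] for i in p) for p in parts])  # [6144, 6144]
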
@@ -189,16 +194,12 @@ def _balance_split_batch(self, data_batches, partition_size):
         get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}")
         return balanced_batches
 
-    def _create_padding_sample(
+    def _create_padding_item(
         self,
         pad_len: int,
         pack_max_length: int,
-        is_qwen3_vl: bool = False,
-        has_rollout_routed_experts: bool = False,
-        has_rollout_logprobs: bool = True,
-        n_routed_experts: int | None = None,
         split_size: int = 1024,
-    ):
+    ) -> WorkerInputItem:
         # padding input_ids
         pad_tokens = tuple(
             torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size)
@@ -210,7 +211,7 @@ def _create_padding_sample(
         pad_seq_ctx.num_padding = pad_len
 
         # padding mm positions_ids
-        if is_qwen3_vl:
+        if self.is_qwen3_vl:
             _position_ids_list = []
             for pad_token in pad_tokens:
                 _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1)
@@ -220,17 +221,17 @@ def _create_padding_sample(
             pad_seq_ctx.position_ids = position_ids
 
         # padding rollout routed experts
-        if has_rollout_routed_experts:
-            assert n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
+        if self.has_rollout_routed_experts:
+            assert self.n_routed_experts, "n_routed_experts must be provided when has_rollout_routed_experts is True"
             if pad_len == pack_max_length:
                 pad_rand_index = torch.randint(
                     low=0, high=1, size=(1, 1, 1)
                 )  # add dummy data, true data will be initialized in train worker.fit
             else:
-                pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1))
+                pad_rand_index = torch.randint(low=0, high=self.n_routed_experts, size=(pad_len, 1, 1))
             pad_seq_ctx.rollout_routed_experts = pad_rand_index
 
-        pad_labels = torch.full((1, pad_len), -100, dtype=torch.long, device="cpu")
+        pad_labels = cast(torch.LongTensor, torch.full((1, pad_len), -100, dtype=torch.int64, device="cpu"))
         pad_advantage_length = pack_max_length if pad_len == pack_max_length else math.ceil(pad_len / 1024)
         pad_advantage = torch.full(
             (1, pad_advantage_length),
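
The pad_advantage_length arithmetic follows from how the padding tokens are chunked: input_ids are created as pad_len // split_size full chunks of split_size (1024 by default), presumably plus one remainder chunk in the lines elided from this hunk, so the padding contributes math.ceil(pad_len / 1024) segments to the pack and needs one advantage slot per segment. A quick check of that correspondence, assuming the default split_size:

import math

split_size = 1024
for pad_len in (1024, 1500, 4096):
    # full chunks plus one remainder chunk, if any
    num_chunks = pad_len // split_size + (1 if pad_len % split_size else 0)
    assert num_chunks == math.ceil(pad_len / split_size)
    print(pad_len, num_chunks)  # 1024 -> 1, 1500 -> 2, 4096 -> 4
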
@@ -239,24 +240,27 @@ def _create_padding_sample(
             device="cpu",
         )
         pad_rollout_logprobs = (
-            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if has_rollout_logprobs else None
+            torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") if self.has_rollout_logprobs else None
         )
 
-        return {
+        padding_item: WorkerInputItem = {
             "seq_ctx": pad_seq_ctx,
             "shifted_labels": pad_labels,
             "advantages": pad_advantage,
             "rollout_logprobs": pad_rollout_logprobs,
         }
+        return padding_item
 
-    def _pack(self, mini_batch, pack_max_length):
+    def _rearrange_batch_for_pack(
+        self, mini_batch: list[WorkerInputItem], pack_max_length: int
+    ) -> list[list[WorkerInputItem]]:
         assert len(mini_batch) > 0, "mini_batch should not be empty"
         seqlen_list = []
         for data in mini_batch:
-            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (
-                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"
+            assert data["seq_ctx"].input_ids.numel() <= pack_max_length, (  # type: ignore[union-attr]
+                f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}"  # type: ignore[union-attr]
             )
-            seqlen_list.append(data["seq_ctx"].input_ids.numel())
+            seqlen_list.append(data["seq_ctx"].input_ids.numel())  # type: ignore[union-attr]
         total_length = sum(seqlen_list)
 
         if total_length <= pack_max_length:
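
After the rename, _rearrange_batch_for_pack only groups items into buckets whose total length fits pack_max_length; the concatenation itself moves into the new _pad_and_pack_batches below. A toy first-fit grouping over plain lengths, illustrating the general technique rather than the exact strategy this method uses:

def first_fit_pack(seqlens: list[int], pack_max_length: int) -> list[list[int]]:
    # Place each sequence into the first bucket that still has room; open a new bucket otherwise.
    buckets: list[list[int]] = []
    loads: list[int] = []
    for idx, length in enumerate(seqlens):
        assert length <= pack_max_length, "single sample exceeds pack_max_length"
        for b, load in enumerate(loads):
            if load + length <= pack_max_length:
                buckets[b].append(idx)
                loads[b] += length
                break
        else:  # no existing bucket had room
            buckets.append([idx])
            loads.append(length)
    return buckets


print(first_fit_pack([3000, 2000, 1500, 500], pack_max_length=4096))  # [[0, 3], [1, 2]]
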
@@ -273,15 +277,10 @@ def _pack(self, mini_batch, pack_max_length):
             packed_mini_batches.append(packed_batch)
         return packed_mini_batches
 
-    def _get_data_batches_properties(self, data_batches: list[ColateItem]):
+    def _set_data_batches_properties(self, data_batches: list[WorkerInputItem]):
         """Extract properties from the first element of data_batches."""
         if not data_batches:
-            return {
-                "is_qwen3_vl": False,
-                "has_rollout_routed_experts": False,
-                "has_rollout_logprobs": False,
-                "n_routed_experts": None,
-            }
+            return
 
         first_item = data_batches[0]
         seq_ctx = first_item["seq_ctx"]
@@ -296,114 +295,128 @@ def _get_data_batches_properties(self, data_batches: list[ColateItem]):
         if isinstance(self.model_cfg, BaseComposeConfig):
             language_cfg = self.model_cfg.text_config
 
-        return {
-            "is_qwen3_vl": is_qwen3_vl,
-            "has_rollout_routed_experts": has_rollout_routed_experts,
-            "has_rollout_logprobs": has_rollout_logprobs,
-            "n_routed_experts": language_cfg.n_routed_experts if language_cfg is not None else None,
+        self.is_qwen3_vl = is_qwen3_vl
+        self.has_rollout_routed_experts = has_rollout_routed_experts
+        self.has_rollout_logprobs = has_rollout_logprobs
+        self.n_routed_experts = language_cfg.n_routed_experts if language_cfg is not None else None
+
+    def _pad_and_pack_batches(self, batch4pack: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem:
+        seq_ctx_list = [item["seq_ctx"] for item in batch4pack]
+        label_list = [item["shifted_labels"] for item in batch4pack]
+        advantage_list = [torch.tensor([item["advantages"]]).float().unsqueeze(0) for item in batch4pack]
+        rollout_logprobs_list = [
+            item["rollout_logprobs"] if self.has_rollout_logprobs else None for item in batch4pack
+        ]
+        cur_length = 0
+        for item in batch4pack:
+            cur_length += item["seq_ctx"].input_ids.numel()  # type: ignore[union-attr]
+        padding_len = pack_max_length - cur_length
+
+        if padding_len > 0:
+            padding_item = self._create_padding_item(padding_len, pack_max_length)
+            seq_ctx_list.append(padding_item["seq_ctx"])
+            label_list.append(padding_item["shifted_labels"])
+            advantage_list.append(padding_item["advantages"])
+            rollout_logprobs_list.append(padding_item["rollout_logprobs"])
+
+        packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
+        packed_shifted_labels = torch.cat(label_list, dim=1)  # type: ignore[arg-type]
+        packed_shifted_labels = cast(torch.LongTensor, packed_shifted_labels)
+        cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
+        packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
+        packed_advantages = torch.cat(advantage_list, dim=1)
+        packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
+        if self.has_rollout_logprobs:
+            cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
+            packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
+        else:
+            packed_rollout_logprobs = None
+
+        optimizer_step_packs: WorkerInputItem = {
+            "seq_ctx": packed_seq_ctx,
+            "shifted_labels": packed_shifted_labels,
+            "advantages": packed_advantages,
+            "rollout_logprobs": packed_rollout_logprobs,
         }
+        return optimizer_step_packs
+
+    def _pad_to_max_packs_across_workers(
+        self,
+        packed_data_batches: list[list[list[WorkerInputItem]]],
+        step_idx: int,
+        max_packs: int,
+        pack_max_length: int,
+    ):
+        for dp_rank in range(len(packed_data_batches)):
+            num_current_packs = len(packed_data_batches[dp_rank][step_idx])
+            num_padding_packs = max_packs - num_current_packs
+
+            if num_padding_packs > 0:
+                padding_item = self._create_padding_item(pack_max_length, pack_max_length)
+                padding_items = [padding_item for _ in range(num_padding_packs)]
+                packed_data_batches[dp_rank][step_idx].extend(padding_items)
 
     @ray_method
     def fit(
-        self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int, enable_dp_balance: bool = True
+        self,
+        data_batches: list[WorkerInputItem],
+        pack_max_length: int,
+        rollout_idx: int,
+        enable_dp_balance: bool = True,
     ):
-        batch_props = self._get_data_batches_properties(data_batches)
-        is_qwen3_vl = batch_props["is_qwen3_vl"]
-        has_rollout_routed_experts = batch_props["has_rollout_routed_experts"]
-        has_rollout_logprobs = batch_props["has_rollout_logprobs"]
-        n_routed_experts = batch_props["n_routed_experts"]
+        self._set_data_batches_properties(data_batches)
 
         world_size = len(self.workers)
         dp_size = world_size // self.data_replicate_size
         assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
         optimizer_steps = self.worker_cfg.optimizer_steps
 
+        batches_per_dp_group: list[list[WorkerInputItem]]
         if enable_dp_balance:
             # Redistribute the data across dp_size ranks so that each dp rank gets roughly the same number of tokens
             batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
         else:
             batches_per_dp_group = np.array_split(data_batches, dp_size)
             tokens_in_partition = []
             for batch in batches_per_dp_group:
-                tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in batch))
-            get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
-
-        packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
-        max_packs_per_card = [0] * optimizer_steps
+                dp_group_total_tokens = 0
+                for data in batch:
+                    dp_group_total_tokens += data["seq_ctx"].input_ids.numel()  # type: ignore[union-attr]
+                tokens_in_partition.append(dp_group_total_tokens)
+            self.logger.info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")
+
+        packed_data_batches: list[list[list[WorkerInputItem]]] = [
+            [[] for _ in range(optimizer_steps)] for _ in range(dp_size)
+        ]
+        max_packs_per_step = [0] * optimizer_steps
 
         for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group):
             # Within each worker, split the tokens evenly across optimizer_steps
             if enable_dp_balance:
                 random.shuffle(dp_worker_data_batches)
-            mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
+            mini_batch_for_steps: list[list[WorkerInputItem]] = self._balance_split_batch(
+                dp_worker_data_batches, optimizer_steps
+            )
 
             for step_idx, step_mini_batch in enumerate(mini_batch_for_steps):
-                # pack
-                pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
-                if len(pack_mini_batch) > max_packs_per_card[step_idx]:
-                    max_packs_per_card[step_idx] = len(pack_mini_batch)
-
-                for pack in pack_mini_batch:
-                    seq_ctx_list = [item["seq_ctx"] for item in pack]
-                    label_list = [item["shifted_labels"] for item in pack]
-                    advantage_list = [torch.tensor([item["advantage"]]).float().unsqueeze(0) for item in pack]
-                    rollout_logprobs_list = [
-                        item["rollout_logprobs"] if has_rollout_logprobs else None for item in pack
-                    ]
-                    padding_len = pack_max_length - sum([item["seq_ctx"].input_ids.numel() for item in pack])
-                    if padding_len > 0:
-                        padding_sample = self._create_padding_sample(
-                            padding_len,
-                            pack_max_length,
-                            is_qwen3_vl=is_qwen3_vl,
-                            has_rollout_routed_experts=has_rollout_routed_experts,
-                            has_rollout_logprobs=has_rollout_logprobs,
-                            n_routed_experts=n_routed_experts,
-                        )
-                        seq_ctx_list.append(padding_sample["seq_ctx"])
-                        label_list.append(padding_sample["shifted_labels"])
-                        advantage_list.append(padding_sample["advantages"])
-                        rollout_logprobs_list.append(padding_sample["rollout_logprobs"])
-
-                    packed_seq_ctx = SequenceContext.pack(seq_ctx_list)
-                    packed_shifted_labels = torch.cat(label_list, dim=1)
-                    cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q
-                    packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]
-                    packed_advantages = torch.cat(advantage_list, dim=1)
-                    packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
-                    if has_rollout_logprobs:
-                        cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list]
-                        packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1)
-                    else:
-                        packed_rollout_logprobs = None
-                    packed_data_batches[dp_rank][step_idx].append(
-                        {
-                            "seq_ctx": packed_seq_ctx,
-                            "shifted_labels": packed_shifted_labels,
-                            "advantages": packed_advantages,
-                            "rollout_logprobs": packed_rollout_logprobs,
-                        }
-                    )
+                # rearrange the mini batch to fit into packs of pack_max_length
+                batch4pack_list: list[list[WorkerInputItem]] = self._rearrange_batch_for_pack(
+                    step_mini_batch, pack_max_length
+                )
+                if len(batch4pack_list) > max_packs_per_step[step_idx]:
+                    max_packs_per_step[step_idx] = len(batch4pack_list)
 
-            get_logger().info(f"Gradient accumulation steps: {max_packs_per_card}")
-        # padding for each worker to have same number of packs
-        for dp_rank in range(dp_size):
-            for step_idx in range(optimizer_steps):
-                max_packs = max_packs_per_card[step_idx]
-                num_current_packs = len(packed_data_batches[dp_rank][step_idx])
-                num_padding_packs = max_packs - num_current_packs
-
-                if num_padding_packs > 0:
-                    padding_sample = self._create_padding_sample(
-                        pack_max_length,
-                        pack_max_length,
-                        is_qwen3_vl=is_qwen3_vl,
-                        has_rollout_routed_experts=has_rollout_routed_experts,
-                        has_rollout_logprobs=has_rollout_logprobs,
-                        n_routed_experts=n_routed_experts,
-                    )
-                    padding_samples = [padding_sample for _ in range(num_padding_packs)]
-                    packed_data_batches[dp_rank][step_idx].extend(padding_samples)
+                for batch4pack in batch4pack_list:
+                    # pad and pack batches into a single optimizer-step pack
+                    step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length)
+                    packed_data_batches[dp_rank][step_idx].append(step_pack)
+
+        self.logger.info(f"Gradient accumulation for each optimizer step: {max_packs_per_step}")
+
+        # padding so that each worker has the same number of packs in each optimizer step
+        for step_idx in range(optimizer_steps):
+            max_packs = max_packs_per_step[step_idx]
+            self._pad_to_max_packs_across_workers(packed_data_batches, step_idx, max_packs, pack_max_length)
 
         handles = []
         for worker_idx, worker in enumerate(self.workers):
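
The advantage handling in the new _pad_and_pack_batches keeps one value per packed segment and only expands it to per-token values at the end, using the segment lengths recovered from cu_seq_lens_q. A minimal sketch of that expansion with made-up numbers (variable names mirror the diff):

import torch

# Two packed samples of 3 and 2 tokens; cu_seq_lens_q marks the segment boundaries.
cu_seq_lens_q = torch.tensor([0, 3, 5])
packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1]  # tensor([3, 2])

packed_advantages = torch.tensor([[0.5, -1.0]])  # one advantage per segment, shape (1, 2)
per_token = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1)
print(per_token)  # tensor([[ 0.5000,  0.5000,  0.5000, -1.0000, -1.0000]])
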
