diff --git a/tests/ray/test_grpo_train.py b/tests/ray/test_grpo_train.py index 720d334e5..b06685a62 100644 --- a/tests/ray/test_grpo_train.py +++ b/tests/ray/test_grpo_train.py @@ -55,7 +55,6 @@ def setUp(self): rewards = [item['reward'] for item in group] rewards = torch.tensor(rewards, dtype=torch.float32) advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8) - for i in range(self.prompt_repeat_k): item = group[i] response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist() @@ -67,7 +66,7 @@ def setUp(self): dict( seq_ctx=SequenceContext.from_input_ids((input_ids, ), device="cpu"), shifted_labels=shifted_labels, - advantage=advantages[i].item(), + advantages=advantages[i], ) ) self.data_batches = data_batches @@ -126,8 +125,125 @@ def build_train_controller(self): ray.get(train_controller.__ray_ready__.remote()) return train_controller - def test_grpo_train_and_save(self): + # def test_grpo_train_and_save(self): + # train_controller = self.build_train_controller() + # ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=8192, rollout_idx=0)) + # save_path = os.path.join(self.temp_dir, "hf_test") + # ray.get(train_controller.save_hf.remote(str(save_path))) + + def _create_dummy_item(self, length: int): + """Helper to create a dummy WorkerInputItem""" + input_ids = torch.ones(1, length, dtype=torch.long) + cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32) + cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32) + max_length_q = torch.tensor(length, dtype=torch.int32) + max_length_k = torch.tensor(length, dtype=torch.int32) + seq_ctx = SequenceContext( + input_ids=input_ids, + cu_seq_lens_q=cu_seq_lens_q, + cu_seq_lens_k=cu_seq_lens_k, + max_length_q=max_length_q, + max_length_k=max_length_k, + num_padding=0, + device="cpu", + ) + return { + "seq_ctx": seq_ctx, + "shifted_labels": torch.ones(1, length, dtype=torch.long), + "advantages": torch.rand(1, 1, dtype=torch.float), + "rollout_logprobs": torch.ones(1, length, dtype=torch.float), + } + + def test_controller_logic(self): + """ + Unit tests for RawTrainingController internal logic using the real Ray actor: + - _balance_split_batch + - _create_padding_item + - _rearrange_batch_for_pack + - _pad_and_pack_batches + """ + # 1. Build the real train controller train_controller = self.build_train_controller() - ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0)) - save_path = os.path.join(self.temp_dir, "hf_test") - ray.get(train_controller.save_hf.remote(str(save_path))) + pack_max_length = 100 + + # --- Test 1: _balance_split_batch --- + print("Testing _balance_split_batch...") + # Input: 4 items with lengths 10, 20, 30, 40 + items = [self._create_dummy_item(l) for l in [10, 20, 30, 40]] + dp_size = 2 + + # Call remote method + # 10, 20, 30, 40 -> sum 100 -> avg 50. + # Expected split: [10, 40] (sum 50) and [20, 30] (sum 50) + result = ray.get(train_controller._balance_split_batch.remote(items, dp_size)) + + self.assertEqual(len(result), 2) + self.assertEqual(len(result[0]), 2) + self.assertEqual(len(result[1]), 2) + + # Verify balance + len_group0 = sum(item["seq_ctx"].input_ids.shape[1] for item in result[0]) + len_group1 = sum(item["seq_ctx"].input_ids.shape[1] for item in result[1]) + self.assertEqual(len_group0, 50) + self.assertEqual(len_group1, 50) + + # --- Test 2: _rearrange_batch_for_pack --- + print("Testing _rearrange_batch_for_pack...") + # Input: [40, 40, 30], max=100. With get_seqlen_balanced_partitions, it should be packed as [40, 30] and [40] + items_pack = [self._create_dummy_item(l) for l in [40, 40, 30]] + batches = ray.get(train_controller._rearrange_batch_for_pack.remote(items_pack, pack_max_length)) + + self.assertEqual(len(batches), 2) + self.assertEqual(len(batches[0]), 2) # 40 + 30 = 70 + self.assertEqual(len(batches[1]), 1) # 40 + self.assertEqual(batches[0][0]["seq_ctx"].input_ids.shape[1] + batches[0][1]["seq_ctx"].input_ids.shape[1], 70) + self.assertEqual(batches[1][0]["seq_ctx"].input_ids.shape[1], 40) + # --- Test 3: _pad_and_pack_batches --- + print("Testing _pad_and_pack_batches...") + # Input: First batch with length 70. Should pad 30 to reach 100. Second batch with length 40, should pad 60 to reach 100. + for idx, batch4pack_list in enumerate(batches): + packed_item = ray.get(train_controller._pad_and_pack_batches.remote(batch4pack_list, pack_max_length)) + # Check total length + self.assertEqual(packed_item["seq_ctx"].input_ids.shape[1], pack_max_length) + # idx == 0: + if idx == 0: + # Check cu_seq_lens_q: [0, 40, 70, 100] + expected_cu_lens = torch.tensor([0, 40, 70, 100], dtype=torch.int32) + self.assertTrue(torch.equal(packed_item["seq_ctx"].cu_seq_lens_q, expected_cu_lens)) + # Check padding labels are -100 + self.assertTrue(torch.all(packed_item["shifted_labels"][0, 70:] == -100)) + if idx == 1: + # Check cu_seq_lens_q: [0, 40, 100] + expected_cu_lens = torch.tensor([0, 40, 100], dtype=torch.int32) + self.assertTrue(torch.equal(packed_item["seq_ctx"].cu_seq_lens_q, expected_cu_lens)) + # Check padding labels are -100 + self.assertTrue(torch.all(packed_item["shifted_labels"][0, 40:] == -100)) + + # --- Test 4: _pad_to_max_packs_across_workes --- + pack_dummy = {"dummy": "pack"} + packed_data_batches = [ + [[pack_dummy, pack_dummy]], # Worker 0: 2 packs + [[pack_dummy]] # Worker 1: 1 pack + ] + # Execute the function locally + packed_data_batches = ray.get(train_controller._pad_to_max_packs_across_workes.remote( + packed_data_batches, 0, 2, pack_max_length + )) + # Verification + # Worker 0 should still have 2 packs + self.assertEqual(len(packed_data_batches[0][0]), 2) + + # Worker 1 should now have 2 packs (1 original + 1 padding) + self.assertEqual(len(packed_data_batches[1][0]), 2) + + # Verify the added item is a padding item + added_pack = packed_data_batches[1][0][1] + # Since we used the real _create_padding_item, it should have the correct structure + self.assertIn("seq_ctx", added_pack) + self.assertIn("shifted_labels", added_pack) + self.assertEqual(added_pack["seq_ctx"].input_ids.shape[1], pack_max_length) + self.assertTrue(torch.all(added_pack["shifted_labels"] == -100)) + print("All controller logic tests passed!") + +if __name__ == "__main__": + unittest.main() \ No newline at end of file diff --git a/xtuner/v1/rl/base/__init__.py b/xtuner/v1/rl/base/__init__.py index d75603b57..e57889f7a 100644 --- a/xtuner/v1/rl/base/__init__.py +++ b/xtuner/v1/rl/base/__init__.py @@ -1,6 +1,13 @@ -from .controller import TrainingController, TrainingControllerProxy +from .controller import TrainingController, TrainingControllerProxy, TrainingLogInfo from .loss import BaseRLLossConfig, RLLossContextInputItem -from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem +from .worker import ( + TrainingWorker, + TrainingWorkerClass, + TrainingWorkerProxy, + WorkerConfig, + WorkerInputItem, + WorkerLogItem, +) __all__ = [ @@ -13,4 +20,6 @@ "BaseRLLossConfig", "RLLossContextInputItem", "WorkerLogItem", + "WorkerInputItem", + "TrainingLogInfo", ] diff --git a/xtuner/v1/rl/base/controller.py b/xtuner/v1/rl/base/controller.py index 60d6e6a0d..baa713740 100644 --- a/xtuner/v1/rl/base/controller.py +++ b/xtuner/v1/rl/base/controller.py @@ -1,267 +1,83 @@ -import math import os -from typing import Literal, TypedDict +import time +from pathlib import Path +from typing import Literal import ray import torch from ray.actor import ActorProxy +from typing_extensions import TypedDict -from xtuner.v1.data_proto.sequence_context import SequenceContext -from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.rl.pack import DataBatchPacker from xtuner.v1.train.trainer import LoadCheckpointConfig -from xtuner.v1.utils import ray_method +from xtuner.v1.utils import get_logger, ray_method -from .worker import TrainingWorker, WorkerLogItem +from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem TRAIN_RAY_GET_TIMEOUT = os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600) # default 5 hours -class ColateItem(TypedDict): - seq_ctx: SequenceContext - shifted_labels: torch.Tensor - advantage: float - rollout_logprobs: torch.Tensor | None +class TrainingLogInfo(TypedDict): + worker_log_infos: list[WorkerLogItem] + padding_tokens: int + pack_time: float + train_time: float class RawTrainingController: def __init__(self, workers: list[TrainingWorker]) -> None: self.workers = workers - - # TODO(hha): 这个逻辑不够通用,应该复用 sft 函数,从而支持 expand soft pack - def _get_pack_infos(self, dataset, num_tokens, target, random=None): - inds = list(range(len(dataset))) - if random is not None: - random.shuffle(inds) - - item_buffer = [] - length_buffer = [] - longest = 0 - - pack_infos = [] - for shfl_i in inds: - if num_tokens[shfl_i] + sum(length_buffer) <= target: - item_buffer.append(shfl_i) - length_buffer.append(num_tokens[shfl_i]) - longest = max(longest, num_tokens[shfl_i]) - else: - if len(item_buffer) > 0: - info = { - "indices": item_buffer, - "longest": int(longest), - } - pack_infos.append(info) - - item_buffer = [shfl_i] - length_buffer = [num_tokens[shfl_i]] - longest = num_tokens[shfl_i] - - if len(item_buffer) > 0: - info = { - "indices": item_buffer, - "longest": int(longest), - } - - pack_infos.append(info) - - return pack_infos - - # TODO(hha): 这个逻辑不够通用,和模型绑定了 - def _packing(self, data_batches, pack_max_length, language_cfg): - pack_infos = self._get_pack_infos( - data_batches, - [data["seq_ctx"].input_ids.numel() for data in data_batches], - pack_max_length, + refs = [ + self.workers[0].get_model_cfg.remote(), + self.workers[0].get_worker_cfg.remote(), + self.workers[0].get_data_replicate_size.remote(), + ] + self.model_cfg, self.worker_cfg, self.data_replicate_size = ray.get(refs) + self.pack_max_length = self.worker_cfg.pack_max_length + self.pack_strategy = self.worker_cfg.pack_strategy + self.data_packer = DataBatchPacker( + pack_max_length=self.pack_max_length, + world_size=len(self.workers), + data_replicate_size=self.data_replicate_size, + optimizer_steps=self.worker_cfg.optimizer_steps, + pack_strategy=self.pack_strategy, + worker_log_dir=self.worker_cfg.log_dir, ) - packed_data_batches = [] - - is_qwen3_vl = False - if len(data_batches[0]["seq_ctx"].position_ids.shape) == 3: - is_qwen3_vl = True - - has_rollout_routed_experts = False - if data_batches[0]["seq_ctx"].rollout_routed_experts is not None: - assert language_cfg is not None - has_rollout_routed_experts = True - n_routed_experts = language_cfg.n_routed_experts - - for pack_info in pack_infos: - indices = pack_info["indices"] - total_len = sum([data_batches[i]["seq_ctx"].input_ids.shape[1] for i in indices]) - pad_len = pack_max_length - total_len - seq_ctx_list = [data_batches[i]["seq_ctx"] for i in indices] - label_list = [data_batches[i]["shifted_labels"] for i in indices] - advantage_list = [data_batches[i]["advantage"] for i in indices] - - rollout_logprobs_list = None - if "rollout_logprobs" in data_batches[0] and data_batches[0]["rollout_logprobs"] is not None: - rollout_logprobs_list = [data_batches[i]["rollout_logprobs"] for i in indices] - - if pad_len > 0: - # Reduce the attn calculation time by using multiple short sequence packs - pad_tokens = tuple( - torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu") - for _ in range(pad_len // 1024) - ) - if pad_len % 1024 > 0: - pad_tokens = pad_tokens + ( - torch.zeros(1, pad_len % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu"), - ) - pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu") - pad_seq_ctx.num_padding = pad_len - pad_labels = torch.full( - (1, pad_len), - -100, - dtype=data_batches[0]["shifted_labels"].dtype, - device=data_batches[0]["shifted_labels"].device, - ) - if is_qwen3_vl: - _position_ids_list = [] - for pad_token in pad_tokens: - _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1) - _position_ids_list.append(_position_ids) - pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1) - - if has_rollout_routed_experts: - pad_rand_index = torch.randint(low=0, high=n_routed_experts, size=(pad_len, 1, 1)) - pad_seq_ctx.rollout_routed_experts = pad_rand_index - - seq_ctx_list.append(pad_seq_ctx) - label_list.append(pad_labels) - advantage_list.extend( - [-100] * math.ceil(pad_len / 1024) - ) # can be any number, pad tokens are excluded from the calculation of the loss function. - - if rollout_logprobs_list is not None: - pad_rollout_logprobs = torch.zeros( - 1, - pad_len, - dtype=data_batches[0]["rollout_logprobs"].dtype, - device=data_batches[0]["shifted_labels"].device, - ) - rollout_logprobs_list.append(pad_rollout_logprobs) - - seq_ctx = SequenceContext.pack(seq_ctx_list) - shifted_labels = torch.cat(label_list, dim=1) # (1, max_len) - advantages = torch.tensor(advantage_list).float().unsqueeze(0) # (1, num_samples) - cu_seq_lens_q = seq_ctx.cu_seq_lens_q - num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] - advantages = torch.repeat_interleave(advantages, num_tokens, dim=1) # (1, max_len) - - rollout_logprobs = None - if rollout_logprobs_list is not None: - rollout_logprobs = torch.cat(rollout_logprobs_list, dim=1) # (1, max_len) - - packed_data_batches.append( - { - "seq_ctx": seq_ctx, - "shifted_labels": shifted_labels, - "advantages": advantages, - "rollout_logprobs": rollout_logprobs, - } - ) - return packed_data_batches - - def _grouped_by_max_length(self, packed_data_batches): - # sort 过后可能第一个 batch 会有很多 pad tokens,因为最后一个 pack 可能只有少量真实数据。 - # 比如组成了 16 个 pack,第 16 个 pack 可能只有几条真实数据,剩下的都是 pad tokens。 - # 排序后这条 pack 会被放在最前面,导致 rank0 的第一个 step 消耗的有效 token 数往往少于其他 rank,是正常现象。 - return sorted(packed_data_batches, key=lambda x: x["seq_ctx"].max_length_q, reverse=True) + log_dir = self.worker_cfg.log_dir + self.log_dir = None + if log_dir is not None: + self.log_dir = Path(log_dir) if isinstance(log_dir, str) else log_dir + self.logger = get_logger(log_dir=self.log_dir, tag="TrainingController") + else: + self.logger = get_logger() @ray_method - def fit(self, data_batches: list[ColateItem], pack_max_length: int, rollout_idx: int) -> list[WorkerLogItem]: - has_rollout_routed_experts = False - language_cfg = None - if data_batches[0]["seq_ctx"].rollout_routed_experts is not None: - model_cfg = ray.get(self.workers[0].get_model_cfg.remote()) # type: ignore[attr-defined] - has_rollout_routed_experts = True - language_cfg = model_cfg - if isinstance(model_cfg, BaseComposeConfig): - language_cfg = model_cfg.text_config - - packed_data_batches = self._packing(data_batches, pack_max_length, language_cfg) - # packed_data_batches = self._grouped_by_max_length(packed_data_batches) - - # TODO(hha): 这个逻辑不够通用,和模型绑定了 - is_qwen3_vl = False - if len(packed_data_batches[0]["seq_ctx"].position_ids.shape) == 3: - is_qwen3_vl = True - - # todo: support round up - num_packed_data_batches = len(packed_data_batches) - data_replicate_size = ray.get(self.workers[0].get_data_replicate_size.remote()) # type: ignore[attr-defined] - dp_size = len(self.workers) // data_replicate_size - pad_num = math.ceil(num_packed_data_batches / dp_size) * dp_size - num_packed_data_batches - if pad_num > 0: - # Reduce the attn calculation time by using multiple short sequence packs - assert data_batches[0]["seq_ctx"].input_ids is not None - pad_tokens = tuple( - torch.zeros(1, 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu") - for _ in range(pack_max_length // 1024) - ) - if pack_max_length % 1024 > 0: - assert data_batches[0]["seq_ctx"].input_ids is not None - pad_tokens = pad_tokens + ( - torch.zeros( - 1, pack_max_length % 1024, dtype=data_batches[0]["seq_ctx"].input_ids.dtype, device="cpu" - ), - ) - pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu") # type: ignore - pad_seq_ctx.num_padding = pack_max_length - if is_qwen3_vl: - _position_ids_list = [] - for pad_token in pad_tokens: - _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1) - _position_ids_list.append(_position_ids) - pad_seq_ctx.position_ids = torch.cat(_position_ids_list, dim=-1) # type: ignore - - pad_shifted_labels = torch.full( - (1, pack_max_length), - -100, - dtype=packed_data_batches[0]["shifted_labels"].dtype, - device="cpu", - ) - pad_advantages = torch.full( - (1, pack_max_length), - -100, - dtype=packed_data_batches[0]["advantages"].dtype, - device="cpu", - ) - - if has_rollout_routed_experts: - pad_rand_index = torch.randint( - low=0, - high=1, - size=(1, 1, 1), # add dummy data, true data will be initialized in train worker.fit - ) - pad_seq_ctx.rollout_routed_experts = pad_rand_index - - pad_rollout_logprobs = None - if "rollout_logprobs" in packed_data_batches[0] and packed_data_batches[0]["rollout_logprobs"] is not None: - pad_rollout_logprobs = torch.zeros( - 1, pack_max_length, dtype=packed_data_batches[0]["rollout_logprobs"].dtype, device="cpu" - ) - pad_data = { - "seq_ctx": pad_seq_ctx, - "shifted_labels": pad_shifted_labels, - "advantages": pad_advantages, - "rollout_logprobs": pad_rollout_logprobs, - } - pad_data_samples = [pad_data for _ in range(pad_num)] - packed_data_batches = packed_data_batches + pad_data_samples - - print(f"len(packed_data_batches): {len(packed_data_batches)}") - + def fit( + self, + data_batches: list[WorkerInputItem], + rollout_idx: int, + ) -> TrainingLogInfo: + start_time = time.perf_counter() + packed_data_batches, padding_tokens_num = self.data_packer.pack(data_batches) + pack_end_time = time.perf_counter() handles = [] for worker_idx, worker in enumerate(self.workers): handles.append( worker.fit.remote( # type: ignore[attr-defined] - data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size], + data_batches=packed_data_batches[worker_idx // self.data_replicate_size], rollout_idx=rollout_idx, ) ) - log_infos = ray.get(handles, timeout=TRAIN_RAY_GET_TIMEOUT) - return log_infos + train_end_time = time.perf_counter() + worker_log_infos = ray.get(handles) + train_log_info: TrainingLogInfo = { + "worker_log_infos": worker_log_infos, + "pack_time": pack_end_time - start_time, + "train_time": train_end_time - pack_end_time, + "padding_tokens": padding_tokens_num, + } + return train_log_info @ray_method def offload(self, target: Literal["model", "optimizer", "all"] = "all"): diff --git a/xtuner/v1/rl/base/worker.py b/xtuner/v1/rl/base/worker.py index fd715dcb9..416d380e5 100644 --- a/xtuner/v1/rl/base/worker.py +++ b/xtuner/v1/rl/base/worker.py @@ -1,10 +1,9 @@ import json -import math import os import time from itertools import chain from pathlib import Path -from typing import Dict, Iterable, List, TypeAlias, TypedDict, cast +from typing import Any, Dict, Iterable, List, Literal, TypeAlias, TypedDict, cast import ray import requests @@ -142,7 +141,7 @@ class WorkerConfig(BaseModel): log_dir: str | Path | None = None update_weight_bucket_size_in_gb: float = 0.5 # 512MB seed: None | int = None # if None, use RLTrainer seed - + pack_strategy: Literal["greedy", "balance", "native"] = "greedy" # sft config sft_dataloader_cfg: DataloaderConfig | None = None sft_global_batch_size: int = -1 @@ -412,122 +411,81 @@ def _get_rl_other_log(self, other_log: OtherLog) -> RLOtherLog: } return rl_other_log - @ray_method - def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLogItem: - # NOTE: sglang会清除logger handle, 重新创建 - self.logger = get_logger(log_dir=self.log_dir, tag="TrainingWorker") - loss_cfg = self.config.loss_cfg - num_batches = len(data_batches) - iters_per_step = math.ceil(num_batches / self._optimizer_steps) - if num_batches < self._optimizer_steps: - self.logger.info( - f"Optimizer only step once because num_batches {num_batches} < optimizer_steps {self._optimizer_steps}." - ) - - seq_ctx_list: list[SequenceContext] = [] - loss_ctx_input_list: list[RLLossContextInputItem] = [] - rollout_logprobs_list: list[torch.Tensor | None] = [] - # convert dummy padding experts to real size - - language_cfg = ( - self.config.model_cfg.text_config - if isinstance(self.config.model_cfg, BaseComposeConfig) - else self.config.model_cfg - ) - - for data in data_batches: - seq_ctx = data["seq_ctx"] - pixel_values = seq_ctx.pixel_values - if pixel_values is not None: - if not isinstance(pixel_values, torch.Tensor): - assert isinstance(pixel_values, list), ( - f"pixel_values should be list of tensor, got {type(pixel_values)}" - ) - pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values] - pixel_values = torch.cat(pixel_values, dim=0) - seq_ctx.pixel_values = pixel_values - - rollout_routed_experts = seq_ctx.rollout_routed_experts - if rollout_routed_experts is not None: - if isinstance(rollout_routed_experts, list): - # list[n,l,e] - out_rollout_routed_expert = [] - for rollout_routed_expert in rollout_routed_experts: - if isinstance(rollout_routed_expert, torch.Tensor): - rollout_routed_experts_tensor = torch.randint( - low=0, - high=language_cfg.n_routed_experts, - size=( - rollout_routed_expert.size(0), - language_cfg.num_hidden_layers, - language_cfg.num_experts_per_tok, - ), - ) - out_rollout_routed_expert.append(rollout_routed_experts_tensor) - else: - rollout_routed_expert_refs = rollout_routed_expert - rollout_routed_expert = ray.get(rollout_routed_expert_refs) - # free obj store explicitly - ray._private.internal_api.free(rollout_routed_expert_refs) - out_rollout_routed_expert.append(torch.as_tensor(rollout_routed_expert, dtype=torch.long)) - - seq_ctx.rollout_routed_experts = torch.cat(out_rollout_routed_expert, dim=0) # max_len,l,e - else: - assert isinstance(rollout_routed_experts, torch.Tensor), ( - f"padding experts should be a dummy tensor, bug got {type(rollout_routed_experts)}" - ) - rollout_routed_experts_tensor = torch.randint( - low=0, - high=language_cfg.n_routed_experts, - size=( - self.config.pack_max_length, - language_cfg.num_hidden_layers, - language_cfg.num_experts_per_tok, - ), - ) - seq_ctx.rollout_routed_experts = rollout_routed_experts_tensor + def _resolve_ray_data(self, seq_ctx: SequenceContext, language_cfg) -> SequenceContext: + pixel_values = seq_ctx.pixel_values + if pixel_values is not None: + if not isinstance(pixel_values, torch.Tensor): + assert isinstance(pixel_values, list), ( + f"pixel_values should be list of tensor, got {type(pixel_values)}" + ) + pixel_values = [ray.get(pixel_obf) for pixel_obf in pixel_values] + pixel_values = torch.cat(pixel_values, dim=0) + seq_ctx.pixel_values = pixel_values + + rollout_routed_experts = seq_ctx.rollout_routed_experts + if rollout_routed_experts is not None: + if isinstance(rollout_routed_experts, list): + # list[n,l,e] + out_rollout_routed_expert = [] + for rollout_routed_expert in rollout_routed_experts: + if isinstance(rollout_routed_expert, torch.Tensor): + rollout_routed_experts_tensor = torch.randint( + low=0, + high=language_cfg.n_routed_experts, + size=( + rollout_routed_expert.size(0), + language_cfg.num_hidden_layers, + language_cfg.num_experts_per_tok, + ), + ) + out_rollout_routed_expert.append(rollout_routed_experts_tensor) + else: + rollout_routed_expert_refs = rollout_routed_expert + rollout_routed_expert = ray.get(rollout_routed_expert_refs) + # free obj store explicitly + ray._private.internal_api.free(rollout_routed_expert_refs) + out_rollout_routed_expert.append(torch.as_tensor(rollout_routed_expert, dtype=torch.long)) + + seq_ctx.rollout_routed_experts = torch.cat(out_rollout_routed_expert, dim=0) # max_len,l,e + else: + assert isinstance(rollout_routed_experts, torch.Tensor), ( + f"padding experts should be a dummy tensor, bug got {type(rollout_routed_experts)}" + ) + rollout_routed_experts_tensor = torch.randint( + low=0, + high=language_cfg.n_routed_experts, + size=( + self.config.pack_max_length, + language_cfg.num_hidden_layers, + language_cfg.num_experts_per_tok, + ), + ) + seq_ctx.rollout_routed_experts = rollout_routed_experts_tensor assert seq_ctx.input_ids is not None, "input_ids is None" assert seq_ctx.rollout_routed_experts.size(0) == seq_ctx.input_ids.size(1) - seq_ctx = data["seq_ctx"].to(DEVICE) - rollout_logprobs = data.get("rollout_logprobs", None) - if rollout_logprobs is not None: - rollout_logprobs = rollout_logprobs.to(DEVICE) - rollout_logprobs_list.append(rollout_logprobs) - loss_ctx_input = RLLossContextInputItem( - shifted_labels=data["shifted_labels"], - advantages=data["advantages"], - rollout_logprobs=rollout_logprobs, - ).to(DEVICE) - if self.sp_mesh.size() > 1: - seq_ctx = seq_ctx.split(self.sp_mesh) - loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh) - seq_ctx_list.append(seq_ctx) - loss_ctx_input_list.append(loss_ctx_input) + return seq_ctx - del data_batches - - rank_grad_tokens: torch.Tensor | None = None - for loss_ctx_input in loss_ctx_input_list: - mask = loss_ctx_input.shifted_labels != -100 - grad_tokens = mask.sum() - rank_grad_tokens = grad_tokens if rank_grad_tokens is None else rank_grad_tokens + grad_tokens - rank_grad_tokens = cast(torch.Tensor, rank_grad_tokens) - global_grad_tokens = rank_grad_tokens - dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM) - - # old logprobs are inplaced updated in compute_actor_logprobs - loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list) - sum_entropy: torch.Tensor | None = None - sum_rollout_entropy: torch.Tensor | None = None + def _apply_rollout_is_correction( + self, + seq_ctx_list: list[SequenceContext], + loss_ctx_input_list: list[RLLossContextInputItem], + rollout_logprobs_list: list[torch.Tensor | None], + loss_cfg: BaseRLLossConfig, + ) -> tuple[list[RLLossContextInputItem], Dict[str, Any]]: + """Apply importance sampling corrections to the loss context, compute + metrics like entropy, and log them.""" if len(rollout_logprobs_list) > 0: assert len(rollout_logprobs_list) == len(loss_ctx_input_list), ( f"rollout_logprobs_list {len(rollout_logprobs_list)} vs loss_ctx_input_list {len(loss_ctx_input_list)}" ) + sum_entropy: torch.Tensor | None = None + sum_rollout_entropy: torch.Tensor | None = None all_rollout_is_metrics = [] all_mismatch_metrics = [] + for i, loss_ctx_input in enumerate(loss_ctx_input_list): mask = loss_ctx_input.shifted_labels != -100 entropy = -(cast(torch.Tensor, loss_ctx_input.old_logprobs) * mask).sum() @@ -543,10 +501,12 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo cu_seq_lens = seq_ctx_list[i].cu_seq_lens_q num_tokens = cu_seq_lens[1:] - cu_seq_lens[:-1] + old_log_prob = cast(torch.Tensor, loss_ctx_input.old_logprobs) + rollout_log_prob = cast(torch.Tensor, rollout_logprobs_list[i]) rollout_is_weights, rollout_is_mask, mismatch_metrics, rollout_is_metrics = ( loss_cfg.rollout_is.compute_rollout_importance_weights_and_metrics( - old_log_prob=loss_ctx_input.old_logprobs, - rollout_log_prob=rollout_logprobs_list[i], + old_log_prob=old_log_prob, + rollout_log_prob=rollout_log_prob, num_tokens=num_tokens, response_mask=mask, ) @@ -556,37 +516,107 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo all_rollout_is_metrics.append(rollout_is_metrics) all_mismatch_metrics.append(mismatch_metrics) + metrics = { + "sum_entropy": sum_entropy, + "sum_rollout_entropy": sum_rollout_entropy, + "all_mismatch_metrics": all_mismatch_metrics, + "all_rollout_is_metrics": all_rollout_is_metrics, + } + return loss_ctx_input_list, metrics + + @ray_method + def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int): + # NOTE: sglang会清除logger handle, 重新创建 + self.logger = get_logger(log_dir=self.log_dir, tag="TrainingWorker") + loss_cfg = self.config.loss_cfg + + num_batches = len(data_batches) + assert num_batches == self._optimizer_steps, ( + f"Data batches length {num_batches} must be equal to optimizer_steps {self._optimizer_steps}." + ) + packd_batch_num_per_step = [] + seq_ctx_list: list[SequenceContext] = [] + loss_ctx_input_list: list[RLLossContextInputItem] = [] + rollout_logprobs_list: list[torch.Tensor | None] = [] + language_cfg = ( + self.config.model_cfg.text_config + if isinstance(self.config.model_cfg, BaseComposeConfig) + else self.config.model_cfg + ) + + for step_idx, step_data_batches in enumerate(data_batches): + # number of packed batch num means the gradient accumulation steps + packd_batch_num_per_step.append(len(step_data_batches)) + for data in step_data_batches: + seq_ctx = self._resolve_ray_data(data["seq_ctx"], language_cfg) + rollout_logprobs = data.get("rollout_logprobs") + loss_ctx_input = RLLossContextInputItem( + shifted_labels=data["shifted_labels"], + advantages=data["advantages"], + rollout_logprobs=rollout_logprobs, + ) + seq_ctx = seq_ctx.to(DEVICE) + loss_ctx_input = loss_ctx_input.to(DEVICE) + if self.sp_mesh.size() > 1: + seq_ctx = seq_ctx.split(self.sp_mesh) + loss_ctx_input = loss_ctx_input.sp_split(self.sp_mesh) + seq_ctx_list.append(seq_ctx) + loss_ctx_input_list.append(loss_ctx_input) + rollout_logprobs_list.append(loss_ctx_input.rollout_logprobs) + + del data_batches + + # old logprobs are inplaced updated in compute_actor_logprobs + # TODO: overlap compute_actor_logprobs with _resolve_ray_data + loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list) + loss_ctx_input_list, metrics = self._apply_rollout_is_correction( + seq_ctx_list, loss_ctx_input_list, rollout_logprobs_list, loss_cfg + ) + + rank_grad_tokens = sum((loss_ctx.shifted_labels != -100).sum() for loss_ctx in loss_ctx_input_list) + rank_grad_tokens = cast(torch.Tensor, rank_grad_tokens).to(DEVICE) + global_grad_tokens = rank_grad_tokens.clone() + dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM) + worker_log_item: WorkerLogItem = {"train_entropy": 0.0, "train_metrics": [], "sft_train_metrics": {}} - logger_msg = f"Rollout {rollout_idx}: " - sum_entropy = cast(torch.Tensor, sum_entropy) + log_parts = [] + + sum_entropy = cast(torch.Tensor, metrics["sum_entropy"]) dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM) avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0) worker_log_item["train_entropy"] = avg_sum_entropy.item() - logger_msg += f"avg entropy: {avg_sum_entropy:.4f}" + log_parts.append(f"avg entropy: {avg_sum_entropy:.4f}") + sum_rollout_entropy = metrics.get("sum_rollout_entropy") if sum_rollout_entropy is not None: - sum_rollout_entropy = cast(torch.Tensor, sum_rollout_entropy) dist.all_reduce(sum_rollout_entropy, op=dist.ReduceOp.SUM) avg_rollout_entropy = ( sum_rollout_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0) ) worker_log_item["rollout_entropy"] = avg_rollout_entropy.item() - logger_msg += f", avg rollout entropy: {avg_rollout_entropy:.4f}" + log_parts.append(f"avg rollout entropy: {avg_rollout_entropy:.4f}") + all_mismatch_metrics = metrics.get("all_mismatch_metrics", []) if len(all_mismatch_metrics) > 0: mismatch_metrics = merge_rollout_is_metrics(all_mismatch_metrics, DEVICE) if len(mismatch_metrics) > 0: worker_log_item["mismatch_metrics"] = mismatch_metrics - logger_msg += f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}" + log_parts.append(f"\n rollout mismatch metrics:\n{json.dumps(mismatch_metrics, indent=4)}") + all_rollout_is_metrics = metrics.get("all_rollout_is_metrics", []) if len(all_rollout_is_metrics) > 0: rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE) if len(rollout_is_metrics) > 0: worker_log_item["rollout_is_metrics"] = rollout_is_metrics - logger_msg += f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}" + log_parts.append( + f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}" + ) - if self.rank == 0: - self.logger.info(logger_msg) + logger_msg = f"Rollout {rollout_idx}: " + ", ".join(part for part in log_parts if not part.startswith("\n")) + for part in log_parts: + if part.startswith("\n"): + logger_msg += part + self.logger.info(logger_msg) if self._has_ref: # ref logprobs are inplaced updated in compute_actor_logprobs @@ -607,9 +637,13 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo avg_kl_div = kl_div_sum / global_grad_tokens if global_grad_tokens > 0 else 0 self.logger.info(f"Rollout {rollout_idx}: avg KL divergence: {avg_kl_div:.4f}") - for i in range(0, len(seq_ctx_list), iters_per_step): - batches_seq_ctx = seq_ctx_list[i : i + iters_per_step] - batches_loss_ctx_input = loss_ctx_input_list[i : i + iters_per_step] + start_idx = 0 + for i in range(self._optimizer_steps): + num_packs_this_step = packd_batch_num_per_step[i] + end_idx = start_idx + num_packs_this_step + batches_seq_ctx = seq_ctx_list[start_idx:end_idx] + batches_loss_ctx_input = loss_ctx_input_list[start_idx:end_idx] + start_idx = end_idx LossContext = loss_cfg.loss_ctx_cls batches_loss_kwargs = LossContext.build_batches_loss_kwargs(batches_loss_ctx_input, loss_cfg) @@ -640,7 +674,10 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int) -> WorkerLo f"{key}={value:.4f}" if isinstance(value, float) else f"{key}={value}" for key, value in log_info.items() ) - log_str = f"Rank{self.rank} Rollout {rollout_idx} Step {i}: " + log_str + log_str = ( + f"Rank{self.rank} Rollout {rollout_idx} Step {i}: gradient_accumulation_steps={num_packs_this_step}, " + + log_str + ) self.logger.info(log_str) self._rollout_step += 1 @@ -793,6 +830,10 @@ def get_model_cfg(self): model_cfg = self._engine.model_cfg return model_cfg + @ray_method + def get_worker_cfg(self): + return self.config + @ray_method def offload_model(self): self._engine.put_model_to_device("cpu") diff --git a/xtuner/v1/rl/pack.py b/xtuner/v1/rl/pack.py new file mode 100644 index 000000000..c042ff17a --- /dev/null +++ b/xtuner/v1/rl/pack.py @@ -0,0 +1,399 @@ +import math +from pathlib import Path +from typing import cast + +import numpy as np +import torch + +from xtuner.v1.data_proto.sequence_context import SequenceContext +from xtuner.v1.model.base import TransformerConfig +from xtuner.v1.model.compose.base import BaseComposeConfig +from xtuner.v1.rl.utils import get_seqlen_balanced_partitions +from xtuner.v1.utils import get_logger + +from .base import WorkerInputItem + + +class DataBatchPacker: + def __init__( + self, + pack_max_length: int, + world_size: int, + data_replicate_size: int, + optimizer_steps: int, + pack_strategy: str = "greedy", + model_cfg: TransformerConfig | None = None, + worker_log_dir: str | None = None, + ): + self.pack_max_length = pack_max_length + self.world_size = world_size + self.data_replicate_size = data_replicate_size + self.optimizer_steps = optimizer_steps + if worker_log_dir is not None: + self.worker_log_dir = Path(worker_log_dir) if isinstance(worker_log_dir, str) else worker_log_dir + self.logger = get_logger(log_dir=self.worker_log_dir, tag="TrainingController") + else: + self.logger = get_logger() + + self.data_batch_properties = { + "is_qwen3_vl": False, + "has_rollout_routed_experts": False, + "has_rollout_logprobs": False, + "n_routed_experts": None, + } + self.strategy_map = {"greedy": self.greedy_pack, "balanced": self.balance_pack, "native": self.native_pack} + if pack_strategy not in self.strategy_map: + raise ValueError(f"Unknown packing strategy: {pack_strategy}") + self._pack_impl = self.strategy_map[pack_strategy] + self.dp_size = self.world_size // self.data_replicate_size + self.padding_tokens = 0 + self.model_cfg = model_cfg + + def pack(self, data_batches: list[WorkerInputItem]) -> tuple[list[list[list[WorkerInputItem]]], int]: + self.padding_tokens = 0 + if not data_batches: + return [], 0 + self._set_data_batch_properties(data_batches) + return self._pack_impl(data_batches), self.padding_tokens + + def greedy_pack(self, data_batches: list[WorkerInputItem]) -> list[list[list[WorkerInputItem]]]: + # 策略核心:贪心打包 + # 1. 使用贪心算法将所有样本打包成一个一维的 pack 列表。 + # 此过程不考虑 DP 和优化步骤,目标是尽可能填满每个 pack。 + pack_infos = self._get_pack_infos( + data_batches, + [data["seq_ctx"].input_ids.numel() for data in data_batches], # type: ignore[union-attr] + self.pack_max_length, + ) + total_data_batches: list[WorkerInputItem] = [] + + # 2. 遍历打包信息,将每个 pack 内的样本拼接并填充到 pack_max_length。 + for pack_info in pack_infos: + indices = pack_info["indices"] + batch4pack = [data_batches[i] for i in indices] + packed_item = self._pad_and_pack_batches(batch4pack, self.pack_max_length) + total_data_batches.append(packed_item) + + # 3. 为了均匀分配,填充整个 batch,使其总 pack 数能被 dp_size 整除。 + dp_size = self.world_size // self.data_replicate_size + num_packed_data_batches = len(total_data_batches) + pad_num = math.ceil(num_packed_data_batches / dp_size) * dp_size - num_packed_data_batches + if pad_num > 0: + padding_items = self._create_padding_item(self.pack_max_length, self.pack_max_length) + pad_data_samples = [padding_items for _ in range(pad_num)] + self.padding_tokens += pad_num * self.pack_max_length + total_data_batches = total_data_batches + pad_data_samples + + # 4. 将填充后的 pack 列表按 dp_size 和 optimizer_steps 重新分配。 + each_dp_batches_num = len(total_data_batches) // dp_size + iters_per_step = math.ceil(each_dp_batches_num // self.optimizer_steps) + actual_optimizer_steps = math.ceil(each_dp_batches_num // iters_per_step) + packed_data_batches: list[list[list[WorkerInputItem]]] = [ + [[] for _ in range(actual_optimizer_steps)] for _ in range(dp_size) + ] + for dp_rank in range(dp_size): + for step in range(actual_optimizer_steps): + start_idx = dp_rank * each_dp_batches_num + step * iters_per_step + end_idx = min(start_idx + iters_per_step, each_dp_batches_num * (dp_rank + 1)) + packed_data_batches[dp_rank][step] = total_data_batches[start_idx:end_idx] + return packed_data_batches + + def balance_pack(self, data_batches: list[WorkerInputItem]) -> list[list[list[WorkerInputItem]]]: + # 策略核心:层层 token 均衡 + # 目标是让每个 DP rank 在每个 optimizer_step 中处理的 token 数都尽可能接近。 + packed_data_batches: list[list[list[WorkerInputItem]]] = [ + [[] for _ in range(self.optimizer_steps)] for _ in range(self.dp_size) + ] + # 1. 按照 dp_size 对数据进行重新分配,保证每个 dp rank 上的 token 数量大致相同 + batches_per_dp_group: list[list[WorkerInputItem]] = self._balance_split_batch(data_batches, self.dp_size) + max_packs_per_step = [0] * self.optimizer_steps + + for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group): + # 2. 在每个 DP 组内部,根据 token 数将数据均衡地分给 optimizer_steps 个 mini-batch。 + mini_batch_for_steps: list[list[WorkerInputItem]] = self._balance_split_batch( + dp_worker_data_batches, self.optimizer_steps + ) + for step_idx, step_mini_batch in enumerate(mini_batch_for_steps): + # 3. 第三次均衡:在每个 mini-batch 内部,再次进行均衡打包,并记录每个 step 的最大 pack 数。 + self._pack_mini_batches_for_each_optimizer_step( + packed_data_batches, step_mini_batch, dp_rank, step_idx, self.pack_max_length + ) + if len(packed_data_batches[dp_rank][step_idx]) > max_packs_per_step[step_idx]: + max_packs_per_step[step_idx] = len(packed_data_batches[dp_rank][step_idx]) + + self.logger.info(f"Gradient accumulation for each optimizer steps: {max_packs_per_step}") + + # 4. 最终填充:根据记录的最大 pack 数,将所有 DP rank 在每个 step 的 pack 数量填充至一致。 + for step_idx in range(self.optimizer_steps): + max_packs = max_packs_per_step[step_idx] + packed_data_batches = self._pad_to_max_packs_across_workes( + packed_data_batches, step_idx, max_packs, self.pack_max_length + ) + return packed_data_batches + + def native_pack(self, data_batches: list[WorkerInputItem]) -> list[list[list[WorkerInputItem]]]: + # 策略核心:按样本数量朴素切分,保证样本顺序 + # 这种方法不考虑 token 长度,仅保证每个 DP rank 和 optimizer_step 分到的样本数量大致相等。 + packed_data_batches: list[list[list[WorkerInputItem]]] = [ + [[] for _ in range(self.optimizer_steps)] for _ in range(self.dp_size) + ] + batches_per_dp_group: list[list[WorkerInputItem]] = np.array_split(data_batches, self.dp_size) + max_packs_per_step = [0] * self.optimizer_steps + + for dp_rank, dp_worker_data_batches in enumerate(batches_per_dp_group): + mini_batch_for_steps: list[list[WorkerInputItem]] = np.array_split( + dp_worker_data_batches, self.optimizer_steps + ) + for step_idx, step_mini_batch in enumerate(mini_batch_for_steps): + self._pack_mini_batches_for_each_optimizer_step( + packed_data_batches, step_mini_batch, dp_rank, step_idx, self.pack_max_length + ) + if len(packed_data_batches[dp_rank][step_idx]) > max_packs_per_step[step_idx]: + max_packs_per_step[step_idx] = len(packed_data_batches[dp_rank][step_idx]) + + self.logger.info(f"Gradient accumulation for each optimizer steps: {max_packs_per_step}") + + # padding for each worker to have same number of packs in each optimizer step + for step_idx in range(self.optimizer_steps): + max_packs = max_packs_per_step[step_idx] + packed_data_batches = self._pad_to_max_packs_across_workes( + packed_data_batches, step_idx, max_packs, self.pack_max_length + ) + return packed_data_batches + + def _get_pack_infos(self, dataset, num_tokens, target, random=None): + inds = list(range(len(dataset))) + if random is not None: + random.shuffle(inds) + + item_buffer = [] + length_buffer = [] + longest = 0 + + pack_infos = [] + for shfl_i in inds: + if num_tokens[shfl_i] + sum(length_buffer) <= target: + item_buffer.append(shfl_i) + length_buffer.append(num_tokens[shfl_i]) + longest = max(longest, num_tokens[shfl_i]) + else: + if len(item_buffer) > 0: + info = { + "indices": item_buffer, + "longest": int(longest), + } + pack_infos.append(info) + + item_buffer = [shfl_i] + length_buffer = [num_tokens[shfl_i]] + longest = num_tokens[shfl_i] + if len(item_buffer) > 0: + info = { + "indices": item_buffer, + "longest": int(longest), + } + + pack_infos.append(info) + return pack_infos + + def _balance_split_batch(self, data_batches: list[WorkerInputItem], partition_size) -> list[list[WorkerInputItem]]: + """Reorder the data on single controller such that each dp rank gets + similar total tokens.""" + global_seqlen_lst = [data["seq_ctx"].input_ids.numel() for data in data_batches] # type: ignore[union-attr] + global_partition_lst = get_seqlen_balanced_partitions( + global_seqlen_lst, k_partitions=partition_size, equal_size=True + ) + balanced_batches = [] + tokens_in_partition = [] + for partition in global_partition_lst: + partition_batch = [data_batches[i] for i in partition] + tokens_in_partition.append(sum(data["seq_ctx"].input_ids.numel() for data in partition_batch)) + balanced_batches.append(partition_batch) + get_logger().info(f"Balanced split into {partition_size} partitions with tokens: {tokens_in_partition}") + return balanced_batches + + def _set_data_batch_properties(self, data_batches: list[WorkerInputItem]): + if not data_batches: + return + + first_item = data_batches[0] + seq_ctx = first_item["seq_ctx"] + + self.data_batch_properties["is_qwen3_vl"] = ( + seq_ctx.position_ids is not None and len(seq_ctx.position_ids.shape) == 3 + ) + self.data_batch_properties["has_rollout_logprobs"] = ( + "rollout_logprobs" in first_item and first_item["rollout_logprobs"] is not None + ) + self.data_batch_properties["has_rollout_routed_experts"] = seq_ctx.rollout_routed_experts is not None + + language_cfg = None + if self.data_batch_properties["has_rollout_routed_experts"]: + language_cfg = self.model_cfg + if isinstance(self.model_cfg, BaseComposeConfig): + language_cfg = self.model_cfg.text_config + + self.data_batch_properties["n_routed_experts"] = ( + language_cfg.n_routed_experts if language_cfg is not None else None + ) + self.logger.info(f"Data batch properties set: {self.data_batch_properties}") + + def _pad_and_pack_batches(self, batch4pack: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem: + seq_ctx_list = [item["seq_ctx"] for item in batch4pack] + label_list = [item["shifted_labels"] for item in batch4pack] + advantage_list = [] + for item in batch4pack: + advantages = item["advantages"].reshape(1, -1) + advantage_list.append(advantages) + rollout_logprobs_list = [ + item["rollout_logprobs"] if self.data_batch_properties["has_rollout_logprobs"] else None + for item in batch4pack + ] + cur_length = 0 + for item in batch4pack: + cur_length += item["seq_ctx"].input_ids.numel() # type: ignore[union-attr] + padding_len = pack_max_length - cur_length + + if padding_len > 0: + padding_item = self._create_padding_item(padding_len, pack_max_length) + self.padding_tokens += padding_len + seq_ctx_list.append(padding_item["seq_ctx"]) + label_list.append(padding_item["shifted_labels"]) + advantage_list.append(padding_item["advantages"]) + rollout_logprobs_list.append(padding_item["rollout_logprobs"]) + + packed_seq_ctx = SequenceContext.pack(seq_ctx_list) + packed_shifted_labels = torch.cat(label_list, dim=1) # type: ignore[arg-type] + packed_shifted_labels = cast(torch.LongTensor, packed_shifted_labels) + cu_seq_lens_q = packed_seq_ctx.cu_seq_lens_q + packed_num_tokens = cu_seq_lens_q[1:] - cu_seq_lens_q[:-1] + packed_advantages = torch.cat(advantage_list, dim=1) + packed_advantages = torch.repeat_interleave(packed_advantages, packed_num_tokens, dim=1) + if self.data_batch_properties["has_rollout_logprobs"]: + cast_rollout_logprobs_list = [cast(torch.Tensor, item) for item in rollout_logprobs_list] + packed_rollout_logprobs = torch.cat(cast_rollout_logprobs_list, dim=1) + else: + packed_rollout_logprobs = None + + optimizer_step_packs: WorkerInputItem = { + "seq_ctx": packed_seq_ctx, + "shifted_labels": packed_shifted_labels, + "advantages": packed_advantages, + "rollout_logprobs": packed_rollout_logprobs, + } + return optimizer_step_packs + + def _pad_to_max_packs_across_workes( + self, + packed_data_batches: list[list[list[WorkerInputItem]]], + step_idx: int, + max_packs: int, + pack_max_length: int, + ): + for dp_rank in range(len(packed_data_batches)): + num_current_packs = len(packed_data_batches[dp_rank][step_idx]) + num_padding_packs = max_packs - num_current_packs + + if num_padding_packs > 0: + padding_item = self._create_padding_item(pack_max_length, pack_max_length) + self.padding_tokens += num_padding_packs * pack_max_length + padding_items = [padding_item for _ in range(num_padding_packs)] + packed_data_batches[dp_rank][step_idx].extend(padding_items) + return packed_data_batches + + def _pack_mini_batches_for_each_optimizer_step( + self, + packed_data_batches: list[list[list[WorkerInputItem]]], + step_mini_batches: list[WorkerInputItem], + dp_rank: int, + step_idx: int, + pack_max_length: int, + ): + seqlen_list = [] + for data in step_mini_batches: + assert data["seq_ctx"].input_ids.numel() <= pack_max_length, ( # type: ignore[union-attr] + f"Single sample seq len {data['seq_ctx'].input_ids.numel()} exceeds pack_max_length {pack_max_length}" # type: ignore[union-attr] + ) + seqlen_list.append(data["seq_ctx"].input_ids.numel()) # type: ignore[union-attr] + total_length = sum(seqlen_list) + + batch_list_for_pack: list[list[WorkerInputItem]] = [] + if total_length > pack_max_length: + # balance mini batches across gradient accumulation steps + num_packs = math.ceil(total_length / pack_max_length) + partitions_indices = get_seqlen_balanced_partitions( + seqlen_list=seqlen_list, k_partitions=num_packs, equal_size=False + ) + for partition in partitions_indices: + batch_list = [step_mini_batches[i] for i in partition] + batch_list_for_pack.append(batch_list) + else: + batch_list_for_pack = [step_mini_batches] + + for batch4pack in batch_list_for_pack: + # pad and pack batches into a single optimizer step pack + step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length) + packed_data_batches[dp_rank][step_idx].append(step_pack) + + def _create_padding_item( + self, + pad_len: int, + pack_max_length: int, + split_size: int = 1024, + ) -> WorkerInputItem: + # padding input_ids + pad_tokens = tuple( + torch.zeros(1, split_size, dtype=torch.long, device="cpu") for _ in range(pad_len // split_size) + ) + if pad_len % split_size > 0: + pad_tokens = pad_tokens + (torch.zeros(1, pad_len % split_size, dtype=torch.long, device="cpu"),) + pad_tokens = cast(tuple[torch.LongTensor, ...], pad_tokens) + pad_seq_ctx = SequenceContext.from_input_ids(pad_tokens, device="cpu") + pad_seq_ctx.num_padding = pad_len + + # padding mm positions_ids + if self.data_batch_properties["is_qwen3_vl"]: + _position_ids_list = [] + for pad_token in pad_tokens: + _position_ids = torch.arange(pad_token.size(-1)).view(1, 1, -1).expand(3, 1, -1) + _position_ids_list.append(_position_ids) + position_ids = torch.cat(_position_ids_list, dim=-1) + position_ids = cast(torch.LongTensor, position_ids) + pad_seq_ctx.position_ids = position_ids + + # padding rollout routed experts + if self.data_batch_properties["has_rollout_routed_experts"]: + assert self.data_batch_properties["n_routed_experts"], ( + "n_routed_experts must be provided when has_rollout_routed_experts is True" + ) + if pad_len == pack_max_length: + pad_rand_index = torch.randint( + low=0, high=1, size=(1, 1, 1) + ) # add dummy data, true data will be initialized in train worker.fit + else: + pad_rand_index = torch.randint( + low=0, high=self.data_batch_properties["n_routed_experts"], size=(pad_len, 1, 1) + ) + pad_seq_ctx.rollout_routed_experts = pad_rand_index + + pad_labels = cast(torch.LongTensor, torch.full((1, pad_len), -100, dtype=torch.int64, device="cpu")) + pad_advantage_length = pack_max_length if pad_len == pack_max_length else math.ceil(pad_len / split_size) + pad_advantage = torch.full( + (1, pad_advantage_length), + -100, + dtype=torch.float32, + device="cpu", + ) + pad_rollout_logprobs = ( + torch.zeros(1, pad_len, dtype=torch.float32, device="cpu") + if self.data_batch_properties["has_rollout_logprobs"] + else None + ) + + padding_item: WorkerInputItem = { + "seq_ctx": pad_seq_ctx, + "shifted_labels": pad_labels, + "advantages": pad_advantage, + "rollout_logprobs": pad_rollout_logprobs, + } + return padding_item diff --git a/xtuner/v1/rl/utils.py b/xtuner/v1/rl/utils.py index 8958f1736..c58d0acbf 100644 --- a/xtuner/v1/rl/utils.py +++ b/xtuner/v1/rl/utils.py @@ -1,4 +1,5 @@ import atexit +import heapq import signal import subprocess from typing import Any @@ -72,3 +73,152 @@ def signal_handler(signum, frame): signal.signal(signal.SIGTERM, signal_handler) atexit.register(cleanup_once) + + +# Adapted from https://github.com/volcengine/verl/blob/eb6991a622e15c494ee8403e2289708b2a3b278f/verl/utils/seqlen_balancing.py#L37 +def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool): + # see: https://en.wikipedia.org/wiki/Largest_differencing_method + class Set: + def __init__(self) -> None: + self.sum = 0 + self.items: list[tuple[int, int]] = [] + + def add(self, idx: int, val: int): + self.items.append((idx, val)) + self.sum += val + + def merge(self, other): + for idx, val in other.items: + self.items.append((idx, val)) + self.sum += val + + def __lt__(self, other): + if self.sum != other.sum: + return self.sum < other.sum + if len(self.items) != len(other.items): + return len(self.items) < len(other.items) + return self.items < other.items + + class State: + def __init__(self, items: list[tuple[int, int]], k: int) -> None: + self.k = k + # sets should always be decreasing order + self.sets = [Set() for _ in range(k)] + assert len(items) in [1, k], f"{len(items)} not in [1, {k}]" + for i, (idx, seqlen) in enumerate(items): + self.sets[i].add(idx=idx, val=seqlen) + self.sets = sorted(self.sets, reverse=True) + + def get_partitions(self): + partitions = [] + for i in range(len(self.sets)): + cur_partition = [] + for idx, _ in self.sets[i].items: + cur_partition.append(idx) + partitions.append(cur_partition) + return partitions + + def merge(self, other): + for i in range(self.k): + self.sets[i].merge(other.sets[self.k - 1 - i]) + self.sets = sorted(self.sets, reverse=True) + + @property + def spread(self) -> int: + return self.sets[0].sum - self.sets[-1].sum + + def __lt__(self, other): + # least heap, let the state with largest spread to be popped first, + # if the spread is the same, let the state who has the largest set + # to be popped first. + if self.spread != other.spread: + return self.spread > other.spread + return self.sets[0] > other.sets[0] + + def __repr__(self) -> str: + repr_str = "[" + for i in range(self.k): + if i > 0: + repr_str += "," + repr_str += "{" + for j, (_, seqlen) in enumerate(self.sets[i].items): + if j > 0: + repr_str += "," + repr_str += str(seqlen) + repr_str += "}" + repr_str += "]" + return repr_str + + sorted_seqlen_list = sorted([(seqlen, i) for i, seqlen in enumerate(seqlen_list)]) + states_pq: list[State] = [] + if equal_size: + assert len(seqlen_list) % k_partitions == 0, f"{len(seqlen_list)} % {k_partitions} != 0" + for offset in range(0, len(sorted_seqlen_list), k_partitions): + items = [] + for i in range(k_partitions): + seqlen, idx = sorted_seqlen_list[offset + i] + items.append((idx, seqlen)) + heapq.heappush(states_pq, State(items=items, k=k_partitions)) + else: + for seqlen, idx in sorted_seqlen_list: + heapq.heappush(states_pq, State(items=[(idx, seqlen)], k=k_partitions)) + + while len(states_pq) > 1: + state0 = heapq.heappop(states_pq) + state1 = heapq.heappop(states_pq) + # merge states + state0.merge(state1) + heapq.heappush(states_pq, state0) + + final_state = states_pq[0] + partitions = final_state.get_partitions() + if equal_size: + for i, partition in enumerate(partitions): + assert len(partition) * k_partitions == len(seqlen_list), ( + f"{len(partition)} * {k_partitions} != {len(seqlen_list)}" + ) + return partitions + + +def get_seqlen_balanced_partitions(seqlen_list: list[int], k_partitions: int, equal_size: bool): + """Calculates partitions of indices from seqlen_list such that the sum of + sequence lengths in each partition is balanced. Uses the Karmarkar-Karp + differencing method. + + This is useful for balancing workload across devices or batches, especially when + dealing with variable sequence lengths. + + Args: + seqlen_list (List[int]): A list of sequence lengths for each item. + k_partitions (int): The desired number of partitions. + equal_size (bool): If True, ensures that each partition has the same number of items. + Requires len(seqlen_list) to be divisible by k_partitions. + If False, partitions can have varying numbers of items, focusing + only on balancing the sum of sequence lengths. + + Returns: + List[List[int]]: A list containing k_partitions lists. Each inner list contains the + original indices of the items assigned to that partition. The indices + within each partition list are sorted. + + Raises: + AssertionError: If len(seqlen_list) < k_partitions. + AssertionError: If equal_size is True and len(seqlen_list) is not divisible by k_partitions. + AssertionError: If any resulting partition is empty. + """ + assert len(seqlen_list) >= k_partitions, f"number of items:[{len(seqlen_list)}] < k_partitions:[{k_partitions}]" + + def _check_and_sort_partitions(partitions): + assert len(partitions) == k_partitions, f"{len(partitions)} != {k_partitions}" + seen_idx = set() + sorted_partitions = [None] * k_partitions + for i, partition in enumerate(partitions): + assert len(partition) > 0, f"the {i}-th partition is empty" + for idx in partition: + seen_idx.add(idx) + sorted_partitions[i] = sorted(partition) + assert seen_idx == set(range(len(seqlen_list))) + return sorted_partitions + + partitions = karmarkar_karp(seqlen_list=seqlen_list, k_partitions=k_partitions, equal_size=equal_size) + return _check_and_sort_partitions(partitions) diff --git a/xtuner/v1/train/rl_trainer.py b/xtuner/v1/train/rl_trainer.py index b606638e9..158ab9a8a 100644 --- a/xtuner/v1/train/rl_trainer.py +++ b/xtuner/v1/train/rl_trainer.py @@ -1,6 +1,5 @@ import json import os -import random from datetime import datetime from pathlib import Path from shutil import rmtree @@ -29,10 +28,11 @@ from xtuner.v1.rl.base import ( TrainingController, TrainingControllerProxy, + TrainingLogInfo, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, - WorkerLogItem, + WorkerInputItem, ) from xtuner.v1.rl.base import TrainingWorker as BaseTrainingWorker from xtuner.v1.train import ResumeConfig @@ -555,13 +555,25 @@ def _train_step(self, rollout_idx: int, data_groups, multimodal_train_infos, ste ) with timer("training", step_timer_dict): - workers_log_item: List[WorkerLogItem] = ray.get( + traning_log_info: TrainingLogInfo = ray.get( self._train_controller.fit.remote( data_batches, pack_max_length=self._train_worker_cfg.pack_max_length, rollout_idx=rollout_idx ) ) - self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx) + workers_log_item = traning_log_info["worker_log_infos"] + self._writer.add_scalar(tag="time/training", scalar_value=step_timer_dict["training"], global_step=rollout_idx) + self._writer.add_scalar( + tag="time/pack_time", scalar_value=traning_log_info["pack_time"], global_step=rollout_idx + ) + self._writer.add_scalar( + tag="time/train_time", scalar_value=traning_log_info["train_time"], global_step=rollout_idx + ) + self._writer.add_scalar( + tag="train_metrics/padding_tokens", + scalar_value=traning_log_info["padding_tokens"], + global_step=rollout_idx, + ) rank0_log_item = workers_log_item[0] # These metrics are already aggregated across distributed workers and logging only the metrics from rank 0. rank0_rollout_is_metrics = rank0_log_item.get("rollout_is_metrics") @@ -762,10 +774,10 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf rollout_logprobs = None seq_ctx = get_train_seq_ctx(input_ids, multimodal_train_info, len(response_ids) - 1) - data_dict = { + data_dict: WorkerInputItem = { "seq_ctx": seq_ctx, "shifted_labels": shifted_labels, - "advantage": advantages[i].item(), + "advantages": advantages[i], "rollout_logprobs": rollout_logprobs, } @@ -774,7 +786,6 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf seq_ctx.rollout_routed_experts = routed_experts # n,layer,expert data_batches.append(data_dict) - random.shuffle(data_batches) rewards_t = torch.tensor(rewards_list).float() if rewards_list else torch.tensor([0.0]).float() advantages_t = torch.tensor(advantages_list).float() if advantages_list else torch.tensor([0.0]).float()