Commit 4ef4b9d

[Feat] Support DataBatchPacker for RLTrainer
1 parent 8cad75e commit 4ef4b9d

File tree

3 files changed: +730 −0 lines

tests/ray/test_pack.py

Lines changed: 131 additions & 0 deletions
import unittest

import torch

from xtuner.v1.data_proto.sequence_context import SequenceContext
from xtuner.v1.rl.base.pack import DataBatchPacker


class TestDataBatchPacker(unittest.TestCase):
    def setUp(self):
        self.pack_max_length = 3072
        self.split_size = 1024

    def _create_dummy_item(self, length: int, val=1):
        # Build a single-sample batch of `length` tokens filled with `val`,
        # so real tokens are distinguishable from padding (0 / -100).
        input_ids = torch.full((1, length), val, dtype=torch.long)
        cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32)
        cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32)
        max_length_q = torch.tensor(length, dtype=torch.int32)
        max_length_k = torch.tensor(length, dtype=torch.int32)
        seq_ctx = SequenceContext(
            input_ids=input_ids,
            cu_seq_lens_q=cu_seq_lens_q,
            cu_seq_lens_k=cu_seq_lens_k,
            max_length_q=max_length_q,
            max_length_k=max_length_k,
            num_padding=0,
            device="cpu",
        )
        return {
            "seq_ctx": seq_ctx,
            "shifted_labels": torch.full((1, length), val, dtype=torch.long),
            "advantages": torch.full((1, length), float(val), dtype=torch.float),
            "rollout_logprobs": torch.full((1, length), float(val), dtype=torch.float),
        }

    def _run_strategy_test(self, strategy, world_size, optimizer_steps, lengths, expected_padding):
        data_batches = [self._create_dummy_item(l, val=7) for l in lengths]
        total_data_tokens = sum(lengths)

        packer = DataBatchPacker(
            pack_max_length=self.pack_max_length,
            world_size=world_size,
            data_replicate_size=1,
            optimizer_steps=optimizer_steps,
            pack_strategy=strategy,
        )

        packed_res, padding_tokens = packer.pack(data_batches)

        self.assertEqual(
            padding_tokens, expected_padding,
            f"Strategy {strategy} padding mismatch. Expected {expected_padding}, got {padding_tokens}")

        # Flatten packed_res[rank][step] into a flat list of packs.
        all_packs = []
        for rank_data in packed_res:
            for step_data in rank_data:
                all_packs.extend(step_data)

        # Every pack is padded to pack_max_length, so total capacity must
        # equal the real tokens plus the reported padding.
        total_capacity = len(all_packs) * self.pack_max_length
        self.assertEqual(total_capacity, total_data_tokens + padding_tokens)

        # Padding uses 0 for input_ids and -100 for labels/advantages; the
        # dummy items use val=7, so these counts recover the real tokens.
        valid_token_count = sum((p["seq_ctx"].input_ids != 0).sum().item() for p in all_packs)
        valid_label_count = sum((p["shifted_labels"] != -100).sum().item() for p in all_packs)
        valid_adv_count = sum((p["advantages"] != -100).sum().item() for p in all_packs)

        self.assertEqual(valid_token_count, total_data_tokens)
        self.assertEqual(valid_label_count, total_data_tokens)
        self.assertEqual(valid_adv_count, total_data_tokens)

    def test_variable_packs(self):
        """Inputs with random token counts, dp=2, optimizer_steps=2.

        - Native:
          1. Split across DP ranks:
             rank0: [1500, 1000, 2800, 1500]
             rank1: [2000, 2100, 1000, 800]
          2. Split across optimizer steps:
             rank0: [1500, 1000], [2800, 1500]
             rank1: [2000, 2100], [1000, 800]
          3. Pack and pad:
             rank0: step0: [2500 -> 3072], step1: [2800 -> 3072], [1500 -> 3072]
             rank1: step0: [2100 -> 3072], [2000 -> 3072], step1: [1800 -> 3072]
          4. Align pack counts across ranks:
             rank0: step0: [2500 -> 3072], [0 -> 3072], step1: [2800 -> 3072], [1500 -> 3072]
             rank1: step0: [2100 -> 3072], [2000 -> 3072], step1: [1800 -> 3072], [0 -> 3072]
          padding_tokens: 3072 - 2500 + 3072 + 3072 - 2800 + 3072 - 1500 + 3072 - 2100 + 3072 - 2000 + 3072 - 1800 + 3072 = 11876

        - Balance:
          1. Balanced split across DP ranks:
             rank0: [2800, 1500, 1000, 1000]
             rank1: [2000, 2100, 1500, 800]
          2. Balanced split across optimizer steps:
             rank0: [2800], [1500, 1000, 1000]
             rank1: [2100, 800], [2000, 1500]
          3. Pack and pad:
             rank0: step0: [2800 -> 3072], step1: [2500 -> 3072], [1000 -> 3072]
             rank1: step0: [2900 -> 3072], step1: [2000 -> 3072], [1500 -> 3072]
          4. Align pack counts across ranks:
             skipped (already aligned)
          padding_tokens: 3072 - 2800 + 3072 - 2500 + 3072 - 1000 + 3072 - 2900 + 3072 - 2000 + 3072 - 1500 = 5732

        - Greedy (greedy-first packing): maximizes pack fill rate and does not
          preserve sample order across steps.
          1. Pack and pad:
             Pack 1: [1500, 1000] -> [2500 -> 3072]
             Pack 2: [2800] -> [2800 -> 3072]
             Pack 3: [1500] -> [1500 -> 3072]
             Pack 4: [2000] -> [2000 -> 3072]
             Pack 5: [2100] -> [2100 -> 3072]
             Pack 6: [1000, 800] -> [1800 -> 3072]
          2. Split across DP ranks:
             rank0: [Pack 1, Pack 2, Pack 3]
             rank1: [Pack 4, Pack 5, Pack 6]
          3. Split across optimizer steps:
             rank0: step0: [Pack 1, Pack 2], step1: [Pack 3]
             rank1: step0: [Pack 4, Pack 5], step1: [Pack 6]
          4. Align pack counts across ranks:
             skipped (already aligned)
          padding_tokens: 3072 - 2500 + 3072 - 2800 + 3072 - 1500 + 3072 - 2000 + 3072 - 2100 + 3072 - 1800 = 5732
        """
        lengths = [1500, 1000, 2800, 1500, 2000, 2100, 1000, 800]
        self._run_strategy_test("native", 2, 2, lengths, 11876)
        self._run_strategy_test("balance", 2, 2, lengths, 5732)
        self._run_strategy_test("greedy", 2, 2, lengths, 5732)

    def test_imbalance_dp_size(self):
        # A single 500-token sample: one pack on one rank plus one all-padding
        # alignment pack on the other rank, i.e. (3072 - 500) + 3072 = 5644.
        lengths = [500]
        for strat in ["native", "balance", "greedy"]:
            self._run_strategy_test(strat, 2, 1, lengths, 5644)

    def test_imbalanced_steps(self):
        lengths = [100, 200, 2500, 3000, 50, 400, 1000, 1500]
        self._run_strategy_test("native", 2, 4, lengths, 15826)
        self._run_strategy_test("balance", 2, 4, lengths, 15826)
        self._run_strategy_test("greedy", 2, 4, lengths, 3538)


if __name__ == "__main__":
    unittest.main()
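
The padding totals asserted above can be reproduced without running the trainer. Below is a minimal sketch of the arithmetic, assuming the "greedy" strategy packs samples sequentially and closes a pack whenever the next sample no longer fits; this next-fit behavior is an assumption consistent with the worked example in the docstring, not a reading of DataBatchPacker's source, and `sequential_pack` is a hypothetical helper, not xtuner API.

# Standalone sketch; reproduces the docstring's greedy pack sizes and padding.
PACK_MAX_LENGTH = 3072

def sequential_pack(lengths, capacity):
    """Return the occupied token count of each pack under next-fit packing."""
    packs, current = [], 0
    for length in lengths:
        if current + length > capacity:  # next sample doesn't fit: close pack
            packs.append(current)
            current = 0
        current += length
    if current:
        packs.append(current)
    return packs

lengths = [1500, 1000, 2800, 1500, 2000, 2100, 1000, 800]
packs = sequential_pack(lengths, PACK_MAX_LENGTH)
padding = len(packs) * PACK_MAX_LENGTH - sum(lengths)
print(packs)    # [2500, 2800, 1500, 2000, 2100, 1800] -- matches the docstring
print(padding)  # 5732, the value asserted for "greedy" in test_variable_packs

The "native" total follows the same identity: after cross-rank alignment there are 8 packs, so padding = 8 * 3072 - 12700 = 11876, exactly what the test asserts.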
