
Commit bedd4d4

add ut and unify training input to WorkerLogItem
1 parent 1695d5a commit bedd4d4

5 files changed (+156, −25)

tests/ray/test_grpo_train.py

Lines changed: 122 additions & 6 deletions
@@ -55,7 +55,6 @@ def setUp(self):
             rewards = [item['reward'] for item in group]
             rewards = torch.tensor(rewards, dtype=torch.float32)
             advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8)
-
             for i in range(self.prompt_repeat_k):
                 item = group[i]
                 response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist()
@@ -67,7 +66,7 @@ def setUp(self):
                     dict(
                         seq_ctx=SequenceContext.from_input_ids((input_ids, ), device="cpu"),
                         shifted_labels=shifted_labels,
-                        advantage=advantages[i].item(),
+                        advantages=advantages[i],
                     )
                 )
         self.data_batches = data_batches
@@ -126,8 +125,125 @@ def build_train_controller(self):
         ray.get(train_controller.__ray_ready__.remote())
         return train_controller
 
-    def test_grpo_train_and_save(self):
+    # def test_grpo_train_and_save(self):
+    #     train_controller = self.build_train_controller()
+    #     ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=8192, rollout_idx=0))
+    #     save_path = os.path.join(self.temp_dir, "hf_test")
+    #     ray.get(train_controller.save_hf.remote(str(save_path)))
+
+    def _create_dummy_item(self, length: int):
+        """Helper to create a dummy WorkerInputItem"""
+        input_ids = torch.ones(1, length, dtype=torch.long)
+        cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32)
+        cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32)
+        max_length_q = torch.tensor(length, dtype=torch.int32)
+        max_length_k = torch.tensor(length, dtype=torch.int32)
+        seq_ctx = SequenceContext(
+            input_ids=input_ids,
+            cu_seq_lens_q=cu_seq_lens_q,
+            cu_seq_lens_k=cu_seq_lens_k,
+            max_length_q=max_length_q,
+            max_length_k=max_length_k,
+            num_padding=0,
+            device="cpu",
+        )
+        return {
+            "seq_ctx": seq_ctx,
+            "shifted_labels": torch.ones(1, length, dtype=torch.long),
+            "advantages": torch.rand(1, 1, dtype=torch.float),
+            "rollout_logprobs": torch.ones(1, length, dtype=torch.float),
+        }
+
+    def test_controller_logic(self):
+        """
+        Unit tests for RawTrainingController internal logic using the real Ray actor:
+        - _balance_split_batch
+        - _create_padding_item
+        - _rearrange_batch_for_pack
+        - _pad_and_pack_batches
+        """
+        # 1. Build the real train controller
         train_controller = self.build_train_controller()
-        ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0))
-        save_path = os.path.join(self.temp_dir, "hf_test")
-        ray.get(train_controller.save_hf.remote(str(save_path)))
+        pack_max_length = 100
+
+        # --- Test 1: _balance_split_batch ---
+        print("Testing _balance_split_batch...")
+        # Input: 4 items with lengths 10, 20, 30, 40
+        items = [self._create_dummy_item(l) for l in [10, 20, 30, 40]]
+        dp_size = 2
+
+        # Call remote method
+        # 10, 20, 30, 40 -> sum 100 -> avg 50.
+        # Expected split: [10, 40] (sum 50) and [20, 30] (sum 50)
+        result = ray.get(train_controller._balance_split_batch.remote(items, dp_size))
+
+        self.assertEqual(len(result), 2)
+        self.assertEqual(len(result[0]), 2)
+        self.assertEqual(len(result[1]), 2)
+
+        # Verify balance
+        len_group0 = sum(item["seq_ctx"].input_ids.shape[1] for item in result[0])
+        len_group1 = sum(item["seq_ctx"].input_ids.shape[1] for item in result[1])
+        self.assertEqual(len_group0, 50)
+        self.assertEqual(len_group1, 50)
+
+        # --- Test 2: _rearrange_batch_for_pack ---
+        print("Testing _rearrange_batch_for_pack...")
+        # Input: [40, 40, 30], max=100. With get_seqlen_balanced_partitions, it should be packed as [40, 30] and [40]
+        items_pack = [self._create_dummy_item(l) for l in [40, 40, 30]]
+        batches = ray.get(train_controller._rearrange_batch_for_pack.remote(items_pack, pack_max_length))
+
+        self.assertEqual(len(batches), 2)
+        self.assertEqual(len(batches[0]), 2)  # 40 + 30 = 70
+        self.assertEqual(len(batches[1]), 1)  # 40
+        self.assertEqual(batches[0][0]["seq_ctx"].input_ids.shape[1] + batches[0][1]["seq_ctx"].input_ids.shape[1], 70)
+        self.assertEqual(batches[1][0]["seq_ctx"].input_ids.shape[1], 40)
+        # --- Test 3: _pad_and_pack_batches ---
+        print("Testing _pad_and_pack_batches...")
+        # Input: First batch with length 70. Should pad 30 to reach 100. Second batch with length 40, should pad 60 to reach 100.
+        for idx, batch4pack_list in enumerate(batches):
+            packed_item = ray.get(train_controller._pad_and_pack_batches.remote(batch4pack_list, pack_max_length))
+            # Check total length
+            self.assertEqual(packed_item["seq_ctx"].input_ids.shape[1], pack_max_length)
+            # idx == 0:
+            if idx == 0:
+                # Check cu_seq_lens_q: [0, 40, 70, 100]
+                expected_cu_lens = torch.tensor([0, 40, 70, 100], dtype=torch.int32)
+                self.assertTrue(torch.equal(packed_item["seq_ctx"].cu_seq_lens_q, expected_cu_lens))
+                # Check padding labels are -100
+                self.assertTrue(torch.all(packed_item["shifted_labels"][0, 70:] == -100))
+            if idx == 1:
+                # Check cu_seq_lens_q: [0, 40, 100]
+                expected_cu_lens = torch.tensor([0, 40, 100], dtype=torch.int32)
+                self.assertTrue(torch.equal(packed_item["seq_ctx"].cu_seq_lens_q, expected_cu_lens))
+                # Check padding labels are -100
+                self.assertTrue(torch.all(packed_item["shifted_labels"][0, 40:] == -100))
+
+        # --- Test 4: _pad_to_max_packs_across_workes ---
+        pack_dummy = {"dummy": "pack"}
+        packed_data_batches = [
+            [[pack_dummy, pack_dummy]],  # Worker 0: 2 packs
+            [[pack_dummy]]  # Worker 1: 1 pack
+        ]
+        # Execute the function locally
+        packed_data_batches = ray.get(train_controller._pad_to_max_packs_across_workes.remote(
+            packed_data_batches, 0, 2, pack_max_length
+        ))
+        # Verification
+        # Worker 0 should still have 2 packs
+        self.assertEqual(len(packed_data_batches[0][0]), 2)
+
+        # Worker 1 should now have 2 packs (1 original + 1 padding)
+        self.assertEqual(len(packed_data_batches[1][0]), 2)
+
+        # Verify the added item is a padding item
+        added_pack = packed_data_batches[1][0][1]
+        # Since we used the real _create_padding_item, it should have the correct structure
+        self.assertIn("seq_ctx", added_pack)
+        self.assertIn("shifted_labels", added_pack)
+        self.assertEqual(added_pack["seq_ctx"].input_ids.shape[1], pack_max_length)
+        self.assertTrue(torch.all(added_pack["shifted_labels"] == -100))
+        print("All controller logic tests passed!")
+
+if __name__ == "__main__":
+    unittest.main()
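Note: the split expected by test_controller_logic ([10, 40] vs [20, 30] for dp_size=2) can be reproduced with a simple greedy length balancer. The sketch below only illustrates that contract with simplified items; it is not the controller's actual _balance_split_batch / get_seqlen_balanced_partitions implementation, and the helper name is made up.

import torch

def balanced_split(items, dp_size):
    """Greedy sketch: hand the longest remaining item to the currently lightest rank."""
    lengths = [item["seq_ctx"].shape[1] for item in items]
    order = sorted(range(len(items)), key=lambda i: -lengths[i])
    groups = [[] for _ in range(dp_size)]
    totals = [0] * dp_size
    for idx in order:
        target = min(range(dp_size), key=lambda r: totals[r])
        groups[target].append(items[idx])
        totals[target] += lengths[idx]
    return groups, totals

# "seq_ctx" is simplified to a bare tensor here; the test uses SequenceContext.
items = [{"seq_ctx": torch.ones(1, n, dtype=torch.long)} for n in (10, 20, 30, 40)]
_, totals = balanced_split(items, dp_size=2)
print(totals)  # [50, 50], matching the assertions in test_controller_logic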

xtuner/v1/rl/base/__init__.py

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,13 @@
 from .controller import TrainingController, TrainingControllerProxy, TrainingStepTimeLog
 from .loss import BaseRLLossConfig, RLLossContextInputItem
-from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem
+from .worker import (
+    TrainingWorker,
+    TrainingWorkerClass,
+    TrainingWorkerProxy,
+    WorkerConfig,
+    WorkerInputItem,
+    WorkerLogItem,
+)
 
 
 __all__ = [
@@ -14,4 +21,5 @@
     "RLLossContextInputItem",
     "WorkerLogItem",
     "TrainingStepTimeLog",
+    "WorkerInputItem",
 ]

xtuner/v1/rl/base/controller.py

Lines changed: 10 additions & 4 deletions
@@ -17,10 +17,10 @@
 from xtuner.v1.train.trainer import LoadCheckpointConfig
 from xtuner.v1.utils import get_logger, ray_method
 
+from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem
 
-TRAIN_RAY_GET_TIMEOUT = os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600)  # default 5 hours
 
-from .worker import TrainingWorker, WorkerInputItem, WorkerLogItem
+TRAIN_RAY_GET_TIMEOUT = os.getenv("XTUNER_TRAIN_RAY_GET_TIMEOUT", 5 * 3600)  # default 5 hours
 
 
 class TrainingStepTimeLog(TypedDict):
@@ -314,7 +314,10 @@ def _set_data_batches_properties(self, data_batches: list[WorkerInputItem]):
     def _pad_and_pack_batches(self, batch4pack: list[WorkerInputItem], pack_max_length: int) -> WorkerInputItem:
         seq_ctx_list = [item["seq_ctx"] for item in batch4pack]
         label_list = [item["shifted_labels"] for item in batch4pack]
-        advantage_list = [torch.tensor([item["advantages"]]).float().unsqueeze(0) for item in batch4pack]
+        advantage_list = []
+        for item in batch4pack:
+            advantages = item["advantages"].reshape(1, -1)
+            advantage_list.append(advantages)
         rollout_logprobs_list = [
             item["rollout_logprobs"] if self.has_rollout_logprobs else None for item in batch4pack
         ]
@@ -366,6 +369,7 @@ def _pad_to_max_packs_across_workes(
            padding_item = self._create_padding_item(pack_max_length, pack_max_length)
            padding_items = [padding_item for _ in range(num_padding_packs)]
            packed_data_batches[dp_rank][step_idx].extend(padding_items)
+        return packed_data_batches
 
     @ray_method
     def fit(
@@ -428,7 +432,9 @@ def fit(
         # padding for each worker to have same number of packs in each optimizer step
         for step_idx in range(optimizer_steps):
             max_packs = max_packs_per_step[step_idx]
-            self._pad_to_max_packs_across_workes(packed_data_batches, step_idx, max_packs, pack_max_length)
+            packed_data_batches = self._pad_to_max_packs_across_workes(
+                packed_data_batches, step_idx, max_packs, pack_max_length
+            )
 
         pack_end_time = time.perf_counter()
         self.logger.info(f"Data packing took {pack_end_time - pack_start_time:.2f} seconds.")

xtuner/v1/rl/base/worker.py

Lines changed: 12 additions & 12 deletions
@@ -516,13 +516,13 @@ def _apply_rollout_is_correction(
             all_rollout_is_metrics.append(rollout_is_metrics)
             all_mismatch_metrics.append(mismatch_metrics)
 
-        worker_log_item: WorkerLogItem = {"train_entropy": 0.0, "train_metrics": [], "sft_train_metrics": {}}
-        logger_msg = f"Rollout {rollout_idx}: "
-        sum_entropy = cast(torch.Tensor, sum_entropy)
-        dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)
-        avg_sum_entropy = sum_entropy / global_grad_tokens if global_grad_tokens > 0 else torch.tensor(0.0)
-        worker_log_item["train_entropy"] = avg_sum_entropy.item()
-        logger_msg += f"avg entropy: {avg_sum_entropy:.4f}"
+        metrics = {
+            "sum_entropy": sum_entropy,
+            "sum_rollout_entropy": sum_rollout_entropy,
+            "all_mismatch_metrics": all_mismatch_metrics,
+            "all_rollout_is_metrics": all_rollout_is_metrics,
+        }
+        return loss_ctx_input_list, metrics
 
     @ray_method
     def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int):
@@ -579,10 +579,7 @@ def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int):
         global_grad_tokens = rank_grad_tokens.clone()
         dist.all_reduce(global_grad_tokens, op=dist.ReduceOp.SUM)
 
-        worker_log_item: WorkerLogItem = {
-            "train_entropy": 0.0,
-            "train_metrics": [],
-        }
+        worker_log_item: WorkerLogItem = {"train_entropy": 0.0, "train_metrics": [], "sft_train_metrics": {}}
         log_parts = []
 
         sum_entropy = cast(torch.Tensor, metrics["sum_entropy"])
@@ -678,7 +675,10 @@ def fit(self, data_batches: list[list[WorkerInputItem]], rollout_idx: int):
                 f"{key}={value:.4f}" if isinstance(value, float) else f"{key}={value}"
                 for key, value in log_info.items()
             )
-            log_str = f"Rank{self.rank} Rollout {rollout_idx} Step {i}: gradient_accumulation_steps={num_packs_this_step}" + log_str
+            log_str = (
+                f"Rank{self.rank} Rollout {rollout_idx} Step {i}: gradient_accumulation_steps={num_packs_this_step}, "
+                + log_str
+            )
             self.logger.info(log_str)
 
         self._rollout_step += 1
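Note: after this refactor, _apply_rollout_is_correction returns a raw metrics dict and fit owns the WorkerLogItem. A condensed sketch of the entropy aggregation that now happens in fit, mirroring the lines removed above (the helper name and argument shapes are illustrative, not the worker's actual code):

import torch
import torch.distributed as dist

def aggregate_train_entropy(metrics: dict, global_grad_tokens: torch.Tensor) -> float:
    # metrics is the dict returned by _apply_rollout_is_correction.
    sum_entropy = metrics["sum_entropy"]
    dist.all_reduce(sum_entropy, op=dist.ReduceOp.SUM)  # sum over all ranks
    if global_grad_tokens > 0:
        avg_entropy = sum_entropy / global_grad_tokens
    else:
        avg_entropy = torch.tensor(0.0)
    return avg_entropy.item()  # stored as WorkerLogItem["train_entropy"]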

xtuner/v1/train/rl_trainer.py

Lines changed: 3 additions & 2 deletions
@@ -33,6 +33,7 @@
     TrainingWorkerClass,
     TrainingWorkerProxy,
     WorkerConfig,
+    WorkerInputItem,
     WorkerLogItem,
 )
 from xtuner.v1.rl.base import TrainingWorker as BaseTrainingWorker
@@ -769,10 +770,10 @@ def _prepare_train_data(self, data_groups, pack_max_length, multimodal_train_inf
                 rollout_logprobs = None
 
             seq_ctx = get_train_seq_ctx(input_ids, multimodal_train_info, len(response_ids) - 1)
-            data_dict = {
+            data_dict: WorkerInputItem = {
                 "seq_ctx": seq_ctx,
                 "shifted_labels": shifted_labels,
-                "advantage": advantages[i].item(),
+                "advantages": advantages[i],
                 "rollout_logprobs": rollout_logprobs,
             }
 
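Note: data_dict is now annotated as WorkerInputItem. The keys it must at least cover, inferred from the dict literals in this commit (a sketch for orientation only; the actual TypedDict is defined in xtuner/v1/rl/base/worker.py and is not shown here):

from typing import Any, Optional, TypedDict

import torch

class WorkerInputItemSketch(TypedDict):
    """Inferred from usage in this commit; not the real WorkerInputItem definition."""

    seq_ctx: Any                               # SequenceContext: packed input_ids plus cu_seq_lens metadata
    shifted_labels: torch.Tensor               # (1, seq_len); -100 marks padding/ignored positions
    advantages: torch.Tensor                   # per-sequence advantage tensor, e.g. shape (1, 1)
    rollout_logprobs: Optional[torch.Tensor]   # (1, seq_len), or None when rollout logprobs are absent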
