tests/ray/test_grpo_train.py (+122 −6)
@@ -55,7 +55,6 @@ def setUp(self):
rewards = [item['reward'] for item in group]
rewards = torch.tensor(rewards, dtype=torch.float32)
advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8)

for i in range(self.prompt_repeat_k):
item = group[i]
response_ids = tokenizer(item['response'], return_tensors='pt')['input_ids'].flatten().tolist()
@@ -67,7 +66,7 @@
dict(
seq_ctx=SequenceContext.from_input_ids((input_ids, ), device="cpu"),
shifted_labels=shifted_labels,
- advantage=advantages[i].item(),
+ advantages=advantages[i],
)
)
self.data_batches = data_batches
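For context: the setUp snippet above computes the standard GRPO advantage, normalizing each reward within its group of `prompt_repeat_k` rollouts, and the PR switches from passing a scalar via `.item()` to passing the advantage tensor through. A minimal worked example with assumed reward values:

import torch

rewards = torch.tensor([1.0, 0.0, 0.5, 0.5])  # one prompt group, k = 4
advantages = (rewards - rewards.mean(0)) / (rewards.std(0) + 1e-8)
# mean = 0.5, unbiased std ≈ 0.4082 -> advantages ≈ [1.2247, -1.2247, 0.0, 0.0]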
@@ -126,8 +125,125 @@ def build_train_controller(self):
ray.get(train_controller.__ray_ready__.remote())
return train_controller

- def test_grpo_train_and_save(self):
+ # def test_grpo_train_and_save(self):
+ #     train_controller = self.build_train_controller()
+ #     ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=8192, rollout_idx=0))
+ #     save_path = os.path.join(self.temp_dir, "hf_test")
+ #     ray.get(train_controller.save_hf.remote(str(save_path)))

def _create_dummy_item(self, length: int):
"""Helper to create a dummy WorkerInputItem"""
input_ids = torch.ones(1, length, dtype=torch.long)
cu_seq_lens_q = torch.tensor([0, length], dtype=torch.int32)
cu_seq_lens_k = torch.tensor([0, length], dtype=torch.int32)
max_length_q = torch.tensor(length, dtype=torch.int32)
max_length_k = torch.tensor(length, dtype=torch.int32)
seq_ctx = SequenceContext(
input_ids=input_ids,
cu_seq_lens_q=cu_seq_lens_q,
cu_seq_lens_k=cu_seq_lens_k,
max_length_q=max_length_q,
max_length_k=max_length_k,
num_padding=0,
device="cpu",
)
return {
"seq_ctx": seq_ctx,
"shifted_labels": torch.ones(1, length, dtype=torch.long),
"advantages": torch.rand(1, 1, dtype=torch.float),
"rollout_logprobs": torch.ones(1, length, dtype=torch.float),
}

def test_controller_logic(self):
"""
Unit tests for RawTrainingController internal logic using the real Ray actor:
- _balance_split_batch
- _create_padding_item
- _rearrange_batch_for_pack
- _pad_and_pack_batches
"""
# 1. Build the real train controller
train_controller = self.build_train_controller()
ray.get(train_controller.fit.remote(self.data_batches, pack_max_length=1024, rollout_idx=0))
save_path = os.path.join(self.temp_dir, "hf_test")
ray.get(train_controller.save_hf.remote(str(save_path)))
pack_max_length = 100

# --- Test 1: _balance_split_batch ---
print("Testing _balance_split_batch...")
# Input: 4 items with lengths 10, 20, 30, 40
items = [self._create_dummy_item(l) for l in [10, 20, 30, 40]]
dp_size = 2

# Call remote method
# 10, 20, 30, 40 -> sum 100 -> avg 50.
# Expected split: [10, 40] (sum 50) and [20, 30] (sum 50)
result = ray.get(train_controller._balance_split_batch.remote(items, dp_size))

self.assertEqual(len(result), 2)
self.assertEqual(len(result[0]), 2)
self.assertEqual(len(result[1]), 2)

# Verify balance
len_group0 = sum(item["seq_ctx"].input_ids.shape[1] for item in result[0])
len_group1 = sum(item["seq_ctx"].input_ids.shape[1] for item in result[1])
self.assertEqual(len_group0, 50)
self.assertEqual(len_group1, 50)

# --- Test 2: _rearrange_batch_for_pack ---
print("Testing _rearrange_batch_for_pack...")
# Input: [40, 40, 30], max=100. With get_seqlen_balanced_partitions, it should be packed as [40, 30] and [40]
items_pack = [self._create_dummy_item(l) for l in [40, 40, 30]]
batches = ray.get(train_controller._rearrange_batch_for_pack.remote(items_pack, pack_max_length))

self.assertEqual(len(batches), 2)
self.assertEqual(len(batches[0]), 2) # 40 + 30 = 70
self.assertEqual(len(batches[1]), 1) # 40
self.assertEqual(batches[0][0]["seq_ctx"].input_ids.shape[1] + batches[0][1]["seq_ctx"].input_ids.shape[1], 70)
self.assertEqual(batches[1][0]["seq_ctx"].input_ids.shape[1], 40)
# --- Test 3: _pad_and_pack_batches ---
print("Testing _pad_and_pack_batches...")
# Input: First batch with length 70. Should pad 30 to reach 100. Second batch with length 40, should pad 60 to reach 100.
for idx, batch4pack_list in enumerate(batches):
packed_item = ray.get(train_controller._pad_and_pack_batches.remote(batch4pack_list, pack_max_length))
# Check total length
self.assertEqual(packed_item["seq_ctx"].input_ids.shape[1], pack_max_length)
# idx == 0:
if idx == 0:
# Check cu_seq_lens_q: [0, 40, 70, 100]
expected_cu_lens = torch.tensor([0, 40, 70, 100], dtype=torch.int32)
self.assertTrue(torch.equal(packed_item["seq_ctx"].cu_seq_lens_q, expected_cu_lens))
# Check padding labels are -100
self.assertTrue(torch.all(packed_item["shifted_labels"][0, 70:] == -100))
if idx == 1:
# Check cu_seq_lens_q: [0, 40, 100]
expected_cu_lens = torch.tensor([0, 40, 100], dtype=torch.int32)
self.assertTrue(torch.equal(packed_item["seq_ctx"].cu_seq_lens_q, expected_cu_lens))
# Check padding labels are -100
self.assertTrue(torch.all(packed_item["shifted_labels"][0, 40:] == -100))

# --- Test 4: _pad_to_max_packs_across_workes ---
pack_dummy = {"dummy": "pack"}
packed_data_batches = [
[[pack_dummy, pack_dummy]], # Worker 0: 2 packs
[[pack_dummy]] # Worker 1: 1 pack
]
# Execute the method on the Ray actor and fetch the result
packed_data_batches = ray.get(train_controller._pad_to_max_packs_across_workes.remote(
packed_data_batches, 0, 2, pack_max_length
))
# Verification
# Worker 0 should still have 2 packs
self.assertEqual(len(packed_data_batches[0][0]), 2)

# Worker 1 should now have 2 packs (1 original + 1 padding)
self.assertEqual(len(packed_data_batches[1][0]), 2)

# Verify the added item is a padding item
added_pack = packed_data_batches[1][0][1]
# Since we used the real _create_padding_item, it should have the correct structure
self.assertIn("seq_ctx", added_pack)
self.assertIn("shifted_labels", added_pack)
self.assertEqual(added_pack["seq_ctx"].input_ids.shape[1], pack_max_length)
self.assertTrue(torch.all(added_pack["shifted_labels"] == -100))
print("All controller logic tests passed!")

if __name__ == "__main__":
unittest.main()
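The new `test_controller_logic` exercises four controller helpers in sequence; the sketches below illustrate the behavior each test pins down. They are simplified stand-ins under stated assumptions, not the actual xtuner implementations. Test 1's `_balance_split_batch` splits a batch across `dp_size` ranks so total token counts stay even; a greedy longest-first partition reproduces the expected result:

def greedy_balance_split(lengths: list[int], dp_size: int) -> list[list[int]]:
    """Assign each sequence (longest first) to the currently lightest bucket."""
    buckets: list[list[int]] = [[] for _ in range(dp_size)]
    totals = [0] * dp_size
    for idx in sorted(range(len(lengths)), key=lambda i: -lengths[i]):
        j = totals.index(min(totals))  # lightest bucket so far
        buckets[j].append(idx)
        totals[j] += lengths[idx]
    return buckets

print(greedy_balance_split([10, 20, 30, 40], 2))
# [[3, 0], [2, 1]] -> per-rank token totals 50 and 50, matching Test 1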
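Test 2's `_rearrange_batch_for_pack` regroups sequences into packs bounded by `pack_max_length`; per the test comment it relies on `get_seqlen_balanced_partitions`, i.e. it balances lengths across the minimum number of packs rather than filling them first-fit. A sketch of that behavior:

import math

def balanced_pack(lengths: list[int], pack_max_length: int) -> list[list[int]]:
    """Balance sequences (longest first) across the minimum number of packs.
    Assumes the balanced packs fit under the cap, as they do in Test 2."""
    num_packs = math.ceil(sum(lengths) / pack_max_length)
    packs: list[list[int]] = [[] for _ in range(num_packs)]
    totals = [0] * num_packs
    for l in sorted(lengths, reverse=True):
        j = totals.index(min(totals))
        packs[j].append(l)
        totals[j] += l
    return packs

print(balanced_pack([40, 40, 30], 100))
# [[40, 30], [40]] -> the two packs asserted in Test 2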
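Test 3's `_pad_and_pack_batches` then flattens one pack into a single row, pads it to `pack_max_length`, records sequence boundaries in `cu_seq_lens` (with the padding counted as a final segment), and masks padded label positions with -100 so the loss ignores them. A sketch of that layout, assuming pad token id 0:

import torch

def pad_and_pack(seqs: list[torch.Tensor], pack_max_length: int, pad_id: int = 0):
    """Concatenate 1-D sequences, pad to pack_max_length, build cu_seq_lens."""
    flat = torch.cat(seqs)
    pad_len = pack_max_length - flat.numel()
    input_ids = torch.cat([flat, torch.full((pad_len,), pad_id, dtype=flat.dtype)])
    bounds = [0]
    for s in seqs:
        bounds.append(bounds[-1] + s.numel())
    if pad_len > 0:
        bounds.append(pack_max_length)  # padding closes the pack as a final segment
    cu_seq_lens = torch.tensor(bounds, dtype=torch.int32)
    labels = input_ids.clone()
    labels[flat.numel():] = -100  # padded positions ignored by the loss
    return input_ids, cu_seq_lens, labels

# Packing lengths [40, 30] to 100 yields cu_seq_lens [0, 40, 70, 100] and
# labels[70:] == -100; a lone [40] yields [0, 40, 100], matching Test 3.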
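Finally, Test 4's `_pad_to_max_packs_across_workes` tops up workers holding fewer packs with all-padding packs so every data-parallel rank runs the same number of forward passes. In miniature (the real method also receives the rollout index, `dp_size`, and `pack_max_length`, and nests packs inside per-worker batch lists):

def pad_to_max_packs(per_worker_packs: list[list[dict]], make_padding_pack) -> list[list[dict]]:
    """Append padding packs so all workers hold the same number of packs."""
    max_packs = max(len(packs) for packs in per_worker_packs)
    return [
        packs + [make_padding_pack() for _ in range(max_packs - len(packs))]
        for packs in per_worker_packs
    ]

# Worker 0 with 2 packs and worker 1 with 1 pack -> both end up with 2,
# and the appended pack carries all -100 labels, matching Test 4.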
xtuner/v1/rl/base/__init__.py (+11 −2)
@@ -1,6 +1,13 @@
-from .controller import TrainingController, TrainingControllerProxy
+from .controller import TrainingController, TrainingControllerProxy, TrainingLogInfo
from .loss import BaseRLLossConfig, RLLossContextInputItem
-from .worker import TrainingWorker, TrainingWorkerClass, TrainingWorkerProxy, WorkerConfig, WorkerLogItem
+from .worker import (
+    TrainingWorker,
+    TrainingWorkerClass,
+    TrainingWorkerProxy,
+    WorkerConfig,
+    WorkerInputItem,
+    WorkerLogItem,
+)


__all__ = [
@@ -13,4 +20,6 @@
"BaseRLLossConfig",
"RLLossContextInputItem",
"WorkerLogItem",
"WorkerInputItem",
"TrainingLogInfo",
]
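With the widened re-exports above, downstream code can import the new symbols straight from the package root; a usage sketch:

from xtuner.v1.rl.base import TrainingLogInfo, WorkerInputItem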