[Refactor] refactor packing in RL train controller and train worker #1393
base: main
Conversation
Pull request overview
This PR refactors the packing logic in the RL training controller and worker components to improve token balancing and code organization. The key changes introduce a Karmarkar-Karp algorithm for balanced partitioning, extract helper methods for better code maintainability, and restructure how data batches are distributed across workers.
Key Changes
- Introduces sequence-length balanced partitioning using the Karmarkar-Karp differencing algorithm to better distribute workload across devices
- Refactors the worker's `fit` method to accept a nested list structure `list[list[WorkerInputItem]]` instead of a flat list, aligning with the new per-step packing approach
- Extracts reusable helper methods (`_resolve_ray_data`, `_apply_rollout_is_correction`, `_create_padding_sample`, `_pack`, `_balance_split_batch`) to reduce code duplication and improve maintainability
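For the first bullet above, here is a minimal sketch of the largest differencing method (Karmarkar-Karp) for k-way partitioning. The real implementation lives in xtuner/v1/rl/utils.py (adapted from verl) and differs in details such as the `equal_size` option, so treat this only as an illustration of the idea, not the PR's code.

```python
import heapq

def kk_partition(seqlens: list[int], k: int) -> list[list[int]]:
    """Sketch of k-way largest differencing: split sequence indices into k groups
    with roughly equal total length (illustrative, not the PR's implementation)."""
    heap: list = []
    for tiebreak, (idx, seqlen) in enumerate(sorted(enumerate(seqlens), key=lambda x: -x[1])):
        # Each heap entry is a partial partition: k buckets of (token_sum, indices).
        buckets = [(seqlen, [idx])] + [(0, []) for _ in range(k - 1)]
        heapq.heappush(heap, (-seqlen, tiebreak, buckets))
    while len(heap) > 1:
        # Merge the two partial partitions with the largest internal spread:
        # pair the heaviest bucket of one with the lightest bucket of the other.
        _, t1, a = heapq.heappop(heap)
        _, t2, b = heapq.heappop(heap)
        a.sort(reverse=True)
        b.sort()
        merged = [(sa + sb, ia + ib) for (sa, ia), (sb, ib) in zip(a, b)]
        spread = max(s for s, _ in merged) - min(s for s, _ in merged)
        heapq.heappush(heap, (-spread, min(t1, t2), merged))
    return [indices for _, indices in heap[0][2]]

# Toy usage: 8 sequence lengths balanced across 4 partitions.
print(kk_partition([900, 40, 700, 350, 360, 50, 820, 60], k=4))
```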
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| xtuner/v1/rl/utils.py | Adds Karmarkar-Karp algorithm implementation with get_seqlen_balanced_partitions function for balanced workload distribution across partitions |
| xtuner/v1/rl/base/worker.py | Refactors fit method to handle nested batch structure, extracts ray data resolution and importance sampling logic into separate methods, adds get_worker_cfg accessor method |
| xtuner/v1/rl/base/controller.py | Major refactoring of packing logic with new balanced splitting, padding creation, and improved data distribution across workers with per-step gradient accumulation support |
    # Adapted from https://github.com/volcengine/verl/blob/main/verl/utils/seqlen_balancing.py
    def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
        # see: https://en.wikipedia.org/wiki/Largest_differencing_method
        class Set:
Copilot AI · Dec 24, 2025
This class implements `__lt__`, but does not implement `__le__` or `__ge__`.
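If a fuller comparison protocol is wanted, one common way to address this kind of finding is `functools.total_ordering`. The sketch below assumes a simplified version of the `Set` class, not the PR's exact fields:

```python
import functools

@functools.total_ordering
class Set:
    """Simplified bucket ordered by total sequence length (illustrative only)."""

    def __init__(self) -> None:
        self.sum = 0
        self.items: list[int] = []

    def add(self, idx: int, val: int) -> None:
        self.items.append(idx)
        self.sum += val

    def _key(self) -> tuple:
        # Order by total length, then by item count, then by the indices themselves.
        return (self.sum, len(self.items), self.items)

    def __eq__(self, other: object) -> bool:
        return isinstance(other, Set) and self._key() == other._key()

    def __lt__(self, other: "Set") -> bool:
        return self._key() < other._key()

    # total_ordering derives __le__, __gt__ and __ge__ from __eq__ and __lt__.
```

Note that `heapq` itself only consults `__lt__`, so this is an interface-completeness question rather than a functional bug.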
            return len(self.items) < len(other.items)
            return self.items < other.items

        class State:
Copilot AI · Dec 24, 2025
This class implements `__lt__`, but does not implement `__le__` or `__ge__`.
Force-pushed from ce11425 to 62ae9fc
xtuner/v1/rl/base/controller.py (outdated)
    get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")

    packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
    max_packs_per_card = [0] * optimizer_steps
Rename to `max_packed_batch_num_per_step`.
`max_packs_per_step` would be more accurate: the maximum number of packs per step.
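To make the naming discussion concrete, here is a toy illustration of the nested layout; the shapes are hypothetical and `max_packs_per_step` follows the naming suggested above:

```python
# packed_data_batches[dp_rank][step][pack] -> one packed training sample (a dict)
dp_size, optimizer_steps = 2, 3
packed_data_batches: list[list[list[dict]]] = [
    [[] for _ in range(optimizer_steps)] for _ in range(dp_size)
]

# Different ranks may produce a different number of packs for the same step,
# so the per-step maximum is tracked and every rank is later padded up to it.
max_packs_per_step = [
    max(len(packed_data_batches[rank][step]) for rank in range(dp_size))
    for step in range(optimizer_steps)
]
```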
    # old logprobs are inplaced updated in compute_actor_logprobs
    loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
    loss_ctx_input_list, metrics = self._apply_rollout_is_correction(
Great! The previously very long `fit` function is now well structured and much easier to read.
Great PR description! It lets me review in the same order as above :) Additionally, when you write out the core function-calling chain corresponding to your original design in the "Key Changes", you will find that some high-level functions are missing in the implementation. Unit tests can play the same role sometimes: for example, if you want to write a unit test for the core padding logic, you need to abstract the related code pieces into a dedicated function.
Force-pushed from 62ae9fc to 2dfbb8b
Force-pushed from 2dfbb8b to ca54108
Force-pushed from ca54108 to bedd4d4
    del data_batches

    # old logprobs are inplaced updated in compute_actor_logprobs
    loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
There is an optimization opportunity here: `self._resolve_ray_data` is a relatively expensive operation, so it could be overlapped with `self.compute_actor_logprobs` to hide the cross-node data loading cost.
The exact implementation may be a bit involved; if you don't want to change it for now, you could add a TODO.
Done, added a TODO for now.
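For reference, one shape the TODO could eventually take is a simple one-step prefetch. `resolve` and `compute` below are generic stand-ins for `self._resolve_ray_data` and `self.compute_actor_logprobs`, so this is only a sketch of the overlap pattern, not the worker's code:

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable

def run_with_prefetch(
    step_batches: list[Any],
    resolve: Callable[[Any], Any],   # I/O-bound, e.g. cross-node ray data resolution
    compute: Callable[[Any], Any],   # GPU-bound, e.g. logprob computation
) -> list[Any]:
    """Resolve step i+1's data on a background thread while step i is being computed."""
    results = []
    with ThreadPoolExecutor(max_workers=1) as pool:
        future = pool.submit(resolve, step_batches[0])
        for i in range(len(step_batches)):
            resolved = future.result()                              # data for the current step
            if i + 1 < len(step_batches):
                future = pool.submit(resolve, step_batches[i + 1])  # prefetch the next step
            results.append(compute(resolved))                       # overlaps with the prefetch
    return results
```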
xtuner/v1/rl/base/controller.py (outdated)
    assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
    optimizer_steps = self.worker_cfg.optimizer_steps

    batches_per_dp_group: list[list[WorkerInputItem]]
Some corner cases may not be handled. For example, if optimizer_steps=16 but there are fewer than 16 samples, will the code raise an error? I suggest covering cases like this with rigorous unit tests.
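A sketch of the suggested unit test, assuming the verl-style signature `get_seqlen_balanced_partitions(seqlen_list, k_partitions, equal_size)` and assuming the desired corner-case behaviour is "return k partitions, some possibly empty" rather than an assertion error; whichever behaviour the team settles on, a test like this pins it down:

```python
import unittest

from xtuner.v1.rl.utils import get_seqlen_balanced_partitions  # helper added in this PR

class TestBalancedSplitCornerCases(unittest.TestCase):
    def test_fewer_samples_than_partitions(self):
        # 5 sequences but 16 requested partitions (e.g. optimizer_steps=16).
        seqlens = [10, 20, 30, 40, 50]
        parts = get_seqlen_balanced_partitions(seqlens, k_partitions=16, equal_size=False)
        # The split should not crash, should return exactly 16 buckets,
        # and every sequence index should be assigned exactly once.
        self.assertEqual(len(parts), 16)
        self.assertEqual(sorted(i for p in parts for i in p), list(range(len(seqlens))))

if __name__ == "__main__":
    unittest.main()
```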
    handles.append(
        worker.fit.remote(  # type: ignore[attr-defined]
    -       data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size],
    +       data_batches=packed_data_batches[worker_idx // self.data_replicate_size],
The `::dp_size` slicing logic cannot be removed.
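To show what the two index expressions select, here is a toy comparison (the values are made up; whether the stride is still needed depends on how `packed_data_batches` is laid out after this refactor, which is exactly the reviewer's point):

```python
dp_size, data_replicate_size = 2, 2
world_size = dp_size * data_replicate_size
packed_data_batches = ["rank0_buckets", "rank1_buckets"]  # toy stand-ins

for worker_idx in range(world_size):
    dp_rank = worker_idx // data_replicate_size
    old_style = packed_data_batches[dp_rank::dp_size]  # strided slice -> a list of entries
    new_style = packed_data_batches[dp_rank]           # direct index  -> a single entry
    print(worker_idx, old_style, new_style)
```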
Motivation
The current xtuner data distribution mechanism has a pack allocation issue that leads to an unstable number of training steps and hurts training effectiveness.
The data distribution pipeline consists of three stages:
1. Pack the `data_batch` by token count, creating one pack per 32K tokens, resulting in N packs
2. Distribute the N packs across the M data-parallel workers
3. Split each worker's packs according to the `optimizer_step` parameter

When `N/M` is not divisible by `optimizer_step`, the actual number of training steps fails to match the expected value.
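To make the mismatch concrete, here is an illustrative example with made-up numbers (not the PR's original figures):

```python
# Illustrative numbers only.
N = 66              # packs produced by splitting the rollout batch at ~32K tokens per pack
M = 4               # data-parallel workers
optimizer_step = 4  # configured gradient-accumulation steps per rollout

print(N % M)                # 2 -> two workers receive 17 packs, the other two only 16
print(17 % optimizer_step)  # 1 -> 17 packs cannot be split into 4 equal micro-batches
# Workers therefore disagree on how many packs each optimizer step consumes, and the
# realized number and size of training steps drifts from what optimizer_step promises.
```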
Key Changes
1. Token-aware Pre-allocation

In `RawTrainingController.fit()` (controller.py), samples are evenly distributed into M workers and further split into `optimizer_step` buckets for each worker, based on token count. This ensures a balanced token distribution across all workers and steps (a sketch of the two-level split follows the call list below):
- `batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)`
- `mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)`
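A minimal sketch of this two-level split, assuming the verl-style helper signature `get_seqlen_balanced_partitions(seqlen_list, k_partitions, equal_size)` and a hypothetical `num_tokens` field on each sample; the real `_balance_split_batch` may differ:

```python
from xtuner.v1.rl.utils import get_seqlen_balanced_partitions  # added in this PR

def two_level_split(data_batches: list[dict], dp_size: int, optimizer_steps: int) -> list[list[list[dict]]]:
    """Split samples into dp_size groups, then each group into optimizer_steps buckets,
    balancing total token count at both levels."""
    seqlens = [sample["num_tokens"] for sample in data_batches]
    per_rank = get_seqlen_balanced_partitions(seqlens, k_partitions=dp_size, equal_size=False)

    buckets: list[list[list[dict]]] = []
    for rank_indices in per_rank:
        rank_samples = [data_batches[i] for i in rank_indices]
        rank_seqlens = [seqlens[i] for i in rank_indices]
        per_step = get_seqlen_balanced_partitions(rank_seqlens, k_partitions=optimizer_steps, equal_size=False)
        buckets.append([[rank_samples[j] for j in step_indices] for step_indices in per_step])
    return buckets  # buckets[dp_rank][step] -> list of samples for that optimizer step
```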
2. Pack & Pad per Bucket

Within each pre-allocated bucket, data is packed and padded so that each pack does not exceed `pack_max_length`. Padding is applied where necessary, and the number of packs per step is aligned across all workers (see the sketch after this list):
- `batch4pack_list = self._rearrange_batch_for_pack(step_mini_batch, pack_max_length)`
- `step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length)`
- `self._pad_to_max_packs_across_workes(packed_data_batches, step_idx, max_packs, pack_max_length)`
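A sketch of the per-bucket pack-and-pad step; the greedy grouping and the dummy padding pack below are stand-ins for `_rearrange_batch_for_pack`, `_pad_and_pack_batches`, and `_create_padding_sample`, not their actual implementations:

```python
def pack_bucket(samples: list[dict], pack_max_length: int) -> list[list[dict]]:
    """Greedily group samples so each pack's total token count stays <= pack_max_length."""
    packs: list[list[dict]] = []
    current: list[dict] = []
    current_tokens = 0
    for sample in sorted(samples, key=lambda s: -s["num_tokens"]):
        if current and current_tokens + sample["num_tokens"] > pack_max_length:
            packs.append(current)
            current, current_tokens = [], 0
        current.append(sample)
        current_tokens += sample["num_tokens"]
    if current:
        packs.append(current)
    return packs

def pad_packs_across_workers(step_packs_per_rank: list[list[list[dict]]], pack_max_length: int) -> None:
    """Append dummy packs so every rank trains on the same number of packs for this step."""
    max_packs = max(len(packs) for packs in step_packs_per_rank)
    for packs in step_packs_per_rank:
        while len(packs) < max_packs:
            packs.append([{"num_tokens": pack_max_length, "is_padding": True}])  # dummy pack
```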
3. Worker-side Training

In `TrainingWorker.fit()` (worker.py), each worker processes its assigned data, including sequence context resolution, logprobs computation, importance sampling correction, and the actual training step (a schematic sketch follows):
- `seq_ctx = self._resolve_ray_data(data["seq_ctx"], language_cfg)`
- `self.compute_actor_logprobs()`
- `self._apply_rollout_is_correction()`
- `train_step()`
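Schematically, the worker-side flow over the new nested input could look like the sketch below; the method names mirror the ones listed above, but the signatures, the `loss_ctx_input` key, and the loop body are simplified guesses rather than the actual `TrainingWorker.fit` code:

```python
def fit(self, data_batches: list[list[dict]]) -> None:
    """data_batches[step][pack]: one inner list per optimizer step (gradient-accumulation chunk)."""
    for step_packs in data_batches:
        # TODO: overlap _resolve_ray_data with compute_actor_logprobs (see review discussion).
        seq_ctx_list = [self._resolve_ray_data(item["seq_ctx"]) for item in step_packs]
        loss_ctx_input_list = [item["loss_ctx_input"] for item in step_packs]

        # Old logprobs are updated in place, then corrected via rollout importance sampling.
        loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
        loss_ctx_input_list, _metrics = self._apply_rollout_is_correction(seq_ctx_list, loss_ctx_input_list)

        self.train_step(seq_ctx_list, loss_ctx_input_list)  # one optimizer step per inner list
```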