Conversation

@YanhuiDua (Collaborator) commented Dec 24, 2025

Motivation

The current xtuner data distribution mechanism has a pack allocation issue that leads to an unstable number of training steps and hurts training effectiveness.

The data distribution pipeline consists of three stages:

  1. Packing Stage: Split input data_batch by token count, creating one pack per 32K tokens, resulting in N packs
  2. Distribution Stage: Evenly distribute N packs across M workers, each worker receives N/M packs
  3. Step Division: Divide packs per worker into steps based on optimizer_step parameter

When N/M is not divisible by optimizer_step, the actual training steps fail to match the expected value.
For example:

N/M = 44                              # packs per worker
optimizer_step = 16                   # expected training steps
packs_per_step = ceil(44 / 16) = 3    # packs allocated per step

# Actual result:
actual_steps = floor(44 / 3) = 14     # complete steps of 3 packs
# Total: 15 steps (14 full + 1 partial) with inconsistent batch sizes
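A minimal, runnable sketch of this arithmetic (the helper name count_actual_steps is illustrative, not xtuner code):

```python
import math

def count_actual_steps(packs_per_worker: int, optimizer_step: int) -> int:
    # Packs allocated per step, rounded up so every pack is assigned somewhere.
    packs_per_step = math.ceil(packs_per_worker / optimizer_step)
    # Full steps of packs_per_step packs, plus one smaller trailing step if packs remain.
    full_steps, remainder = divmod(packs_per_worker, packs_per_step)
    return full_steps + (1 if remainder else 0)

print(count_actual_steps(44, 16))  # 15, not the expected 16
```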

Key Changes

1. Token-aware Pre-allocation

In RawTrainingController.fit() (controller.py), samples are first evenly distributed across the M workers and then further split into optimizer_step buckets per worker, both based on token count. This keeps the token load balanced across all workers and steps (a standalone sketch of the splitting follows the list below):

  • batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
  • mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
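The PR's _balance_split_batch reportedly uses the Karmarkar-Karp algorithm (see the review comments below). As a simpler illustration of token-balanced splitting, a greedy longest-processing-time sketch with hypothetical sample fields might look like this:

```python
import heapq

def balance_split_batch(samples: list[dict], num_buckets: int) -> list[list[dict]]:
    """Greedy LPT split: give the longest remaining sample to the lightest bucket."""
    heap = [(0, idx) for idx in range(num_buckets)]  # (token_count, bucket_idx)
    heapq.heapify(heap)
    buckets: list[list[dict]] = [[] for _ in range(num_buckets)]
    for sample in sorted(samples, key=lambda s: s["num_tokens"], reverse=True):
        tokens, idx = heapq.heappop(heap)
        buckets[idx].append(sample)
        heapq.heappush(heap, (tokens + sample["num_tokens"], idx))
    return buckets
```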

2. Pack & Pad per Bucket

Within each pre-allocated bucket, data is packed so that no pack exceeds pack_max_length, padding is applied where necessary, and the number of packs per step is aligned across all workers (see the sketch after this list):

  • batch4pack_list = self._rearrange_batch_for_pack(step_mini_batch, pack_max_length)
  • step_pack = self._pad_and_pack_batches(batch4pack, pack_max_length)
  • self._pad_to_max_packs_across_workes(packed_data_batches, step_idx, max_packs, pack_max_length)
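A minimal sketch of the pack-and-pad idea for a single bucket, assuming every sample fits within pack_max_length and using hypothetical field names rather than the PR's actual helpers:

```python
def pack_and_pad_bucket(bucket: list[dict], pack_max_length: int, pad_id: int = 0) -> list[list[dict]]:
    """First-fit packing into packs of at most pack_max_length tokens, then pad each pack up to the limit."""
    packs: list[list[dict]] = []
    current: list[dict] = []
    used = 0
    for sample in bucket:
        if current and used + sample["num_tokens"] > pack_max_length:
            packs.append(current)
            current, used = [], 0
        current.append(sample)
        used += sample["num_tokens"]
    if current:
        packs.append(current)
    # Pad every pack to exactly pack_max_length tokens with a dummy sample.
    for pack in packs:
        deficit = pack_max_length - sum(s["num_tokens"] for s in pack)
        if deficit > 0:
            pack.append({"input_ids": [pad_id] * deficit, "num_tokens": deficit, "is_padding": True})
    return packs
```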

3. Worker-side Training

In TrainingWorker.fit() (worker.py), each worker processes its assigned data: sequence context resolution, logprob computation, importance-sampling correction, and the actual training step (a structural sketch follows the list below):

  • seq_ctx = self._resolve_ray_data(data["seq_ctx"], language_cfg)
  • self.compute_actor_logprobs()
  • self._apply_rollout_is_correction()
  • train_step()
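A stubbed outline of that per-step loop, purely to show the call order; the real TrainingWorker.fit() signature and helper return types may differ:

```python
class TrainingWorkerSketch:
    """Illustrative skeleton, not the real TrainingWorker."""

    def fit(self, data_batches: list[list[dict]], language_cfg=None) -> None:
        for step_batches in data_batches:  # one inner list per optimizer step
            seq_ctx_list = [self._resolve_ray_data(b["seq_ctx"], language_cfg) for b in step_batches]
            loss_ctx_list = self.compute_actor_logprobs(seq_ctx_list)
            loss_ctx_list, metrics = self._apply_rollout_is_correction(loss_ctx_list)
            self.train_step(seq_ctx_list, loss_ctx_list)

    # Stubs standing in for the real helpers listed above.
    def _resolve_ray_data(self, seq_ctx, language_cfg): return seq_ctx
    def compute_actor_logprobs(self, seq_ctx_list): return [{} for _ in seq_ctx_list]
    def _apply_rollout_is_correction(self, loss_ctx_list): return loss_ctx_list, {}
    def train_step(self, seq_ctx_list, loss_ctx_list): pass
```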

Copilot AI (Contributor) left a comment

Pull request overview

This PR refactors the packing logic in the RL training controller and worker components to improve token balancing and code organization. The key changes introduce a Karmarkar-Karp algorithm for balanced partitioning, extract helper methods for better code maintainability, and restructure how data batches are distributed across workers.

Key Changes

  • Introduces sequence-length balanced partitioning using the Karmarkar-Karp differencing algorithm to better distribute workload across devices (a standalone sketch follows this list)
  • Refactors worker's fit method to accept nested list structure list[list[WorkerInputItem]] instead of flat list, aligning with the new per-step packing approach
  • Extracts reusable helper methods (_resolve_ray_data, _apply_rollout_is_correction, _create_padding_sample, _pack, _balance_split_batch) to reduce code duplication and improve maintainability
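The partitioning utility is adapted from verl's seqlen_balancing.py (see the snippet further down). Independent of that implementation, a self-contained sketch of the generalized largest differencing method, under the illustrative name kk_partition, could look like:

```python
import heapq

def kk_partition(seqlens: list[int], k: int) -> list[list[int]]:
    """Split item indices into k groups with near-equal token sums (largest differencing method)."""
    # Each heap entry: (negative spread, tiebreak, partitions), where partitions is a
    # descending-sorted list of (token_sum, item_indices) pairs.
    states = []
    for tie, (i, n) in enumerate(enumerate(seqlens)):
        parts = [(n, [i])] + [(0, []) for _ in range(k - 1)]
        states.append((-n, tie, parts))
    heapq.heapify(states)
    tie = len(states)
    while len(states) > 1:
        _, _, a = heapq.heappop(states)  # state with the largest spread
        _, _, b = heapq.heappop(states)
        # Pair a's heaviest partition with b's lightest, and so on, to cancel imbalance.
        merged = [(sa + sb, ia + ib) for (sa, ia), (sb, ib) in zip(a, reversed(b))]
        merged.sort(key=lambda p: -p[0])
        spread = merged[0][0] - merged[-1][0]
        heapq.heappush(states, (-spread, tie, merged))
        tie += 1
    return [indices for _, indices in states[0][2]]

# Example: split 6 sequences into 3 partitions with roughly equal token counts.
print(kk_partition([700, 300, 200, 600, 100, 500], k=3))
```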

Reviewed changes

Copilot reviewed 3 out of 3 changed files in this pull request and generated 12 comments.

File | Description
xtuner/v1/rl/utils.py | Adds Karmarkar-Karp algorithm implementation with get_seqlen_balanced_partitions function for balanced workload distribution across partitions
xtuner/v1/rl/base/worker.py | Refactors fit method to handle nested batch structure, extracts ray data resolution and importance sampling logic into separate methods, adds get_worker_cfg accessor method
xtuner/v1/rl/base/controller.py | Major refactoring of packing logic with new balanced splitting, padding creation, and improved data distribution across workers with per-step gradient accumulation support


# Adapted from https://github.com/volcengine/verl/blob/main/verl/utils/seqlen_balancing.py
def karmarkar_karp(seqlen_list: list[int], k_partitions: int, equal_size: bool):
# see: https://en.wikipedia.org/wiki/Largest_differencing_method
class Set:
Copilot AI commented Dec 24, 2025
This class implements `__lt__`, but does not implement `__le__` or `__ge__`.

return len(self.items) < len(other.items)
return self.items < other.items

class State:
Copilot AI commented Dec 24, 2025

This class implements `__lt__`, but does not implement `__le__` or `__ge__`.

get_logger().info(f"default split into {dp_size} partitions with tokens: {tokens_in_partition}")

packed_data_batches: list[list[list[dict]]] = [[[] for _ in range(optimizer_steps)] for _ in range(dp_size)]
max_packs_per_card = [0] * optimizer_steps
Collaborator:
rename to max_packed_batch_num_per_step

@YanhuiDua (Collaborator, Author) commented Dec 29, 2025

max_packs_per_step is more accurate: the maximum number of packs per step.


# old logprobs are inplaced updated in compute_actor_logprobs
loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
loss_ctx_input_list, metrics = self._apply_rollout_is_correction(
Collaborator:
Great! The previously very long fit function is now well structured and much easier to read.

@jayhenry (Collaborator) commented Dec 26, 2025

Motivation

The current xtuner data distribution mechanism has a pack allocation issue that leads to an unstable number of training steps and hurts training effectiveness.

The data distribution pipeline consists of three stages:

  1. Packing Stage: Split input data_batch by token count, creating one pack per 32K tokens, resulting in N packs
  2. Distribution Stage: Evenly distribute N packs across M workers, each worker receives N/M packs
  3. Step Division: Divide packs per worker into steps based on optimizer_step parameter

When N/M is not divisible by optimizer_step, the actual training steps fail to match the expected value. For example:

N/M = 44                              # packs per worker
optimizer_step = 16                   # expected training steps
packs_per_step = ceil(44 / 16) = 3    # packs allocated per step

# Actual result:
actual_steps = floor(44 / 3) = 14     # complete steps of 3 packs
# Total: 15 steps (14 full + 1 partial) with inconsistent batch sizes

Key Changes

This PR refactors the pipeline to Allocate → Pack & Pad and wraps several methods of TrainController and TrainWorker.

  1. Token-aware pre-allocation: evenly distribute samples into M workers (optional) × optimizer_step buckets based on token count
  2. Pack & pad per bucket: apply packing and padding within each pre-allocated bucket

Great PR description!
Maybe you can add the calling chain of the core packing functions, corresponding to the workflow in Key Changes, for example:

controller.py:
RawTrainingController.fit()
# 1. Token-aware pre-allocation: evenly distribute samples into M workers (optional) × optimizer_step buckets based on token count
-> batches_per_dp_group = self._balance_split_batch(data_batches, dp_size)
-> mini_batch_for_steps = self._balance_split_batch(dp_worker_data_batches, optimizer_steps)
# 2. Pack & pad per bucket: apply packing and padding within each pre-allocated bucket
-> batch4pack_list = self._rearrange_batch_for_pack(step_mini_batch, pack_max_length)   # the old version: pack_mini_batch = self._pack(step_mini_batch, pack_max_length)
-> self._pack_batches()  # pieces of packing code that would be better wrapped in a new function `_pack_batches()`
worker.py:
-> self._create_padding_sample()  # pieces of padding code that would be better wrapped in a new function `_pad_batches()`
# 3. use the packed and padded data batches
-> TrainingWorker.fit()

Then I can review easily in the same order as above :)

Additionally, when you write out the core function calling chain corresponding to your original design in the "Key Changes", you will find that some high-level functions are missing from the implementation, such as _pad_batches(). If you add these high-level functions, you can explain the design more clearly and others can read the code more easily, because readers can reason at the high level and ignore the messy details.

Unit tests can play the same role. For example, if you want to unit-test the core padding logic, you first need to abstract the related code pieces into a function like _pad_batches() and then test it, as sketched below.
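For illustration only (the PR does not define pad_batches; the name and data layout here are assumptions), such a test could pin down the padding invariant directly:

```python
def pad_batches(packs: list[list[dict]], pack_max_length: int, pad_id: int = 0) -> list[list[dict]]:
    """Illustrative stand-in for the suggested _pad_batches(): top up each pack to pack_max_length tokens."""
    for pack in packs:
        deficit = pack_max_length - sum(s["num_tokens"] for s in pack)
        if deficit > 0:
            pack.append({"input_ids": [pad_id] * deficit, "num_tokens": deficit, "is_padding": True})
    return packs


def test_pad_batches_fills_every_pack():
    packs = [
        [{"input_ids": [1] * 10, "num_tokens": 10}],
        [{"input_ids": [1] * 32, "num_tokens": 32}],
    ]
    padded = pad_batches(packs, pack_max_length=32)
    assert all(sum(s["num_tokens"] for s in pack) == 32 for pack in padded)
```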

del data_batches

# old logprobs are inplaced updated in compute_actor_logprobs
loss_ctx_input_list = self.compute_actor_logprobs(seq_ctx_list, loss_ctx_input_list)
Collaborator:
There is an optimization opportunity here. self._resolve_ray_data is a relatively expensive operation; it could be overlapped with the self.compute_actor_logprobs computation to hide the cross-node data-loading cost.

Writing this out may be a bit involved; if you don't want to change it for now, adding a TODO is fine.

Collaborator (Author):
Done; added a TODO for now.

assert world_size % self.data_replicate_size == 0, "world_size must be divisible by data_replicate_size"
optimizer_steps = self.worker_cfg.optimizer_steps

batches_per_dp_group: list[list[WorkerInputItem]]
Collaborator:
Some corner cases may not be covered. For example, if optimizer_steps=16 but there are fewer than 16 data samples, will the code raise an error? It would be good to cover cases like this with rigorous unit tests.

handles.append(
worker.fit.remote( # type: ignore[attr-defined]
data_batches=packed_data_batches[(worker_idx // data_replicate_size) :: dp_size],
data_batches=packed_data_batches[worker_idx // self.data_replicate_size],
Collaborator:
The ::dp_size logic must not be removed.

@hhaAndroid (Collaborator) commented:
The random.shuffle(data_batches) in rl_trainer needs to be removed.
