
Conversation

@xiaoyao0115
Contributor

@xiaoyao0115 xiaoyao0115 commented Oct 28, 2025

This PR is the second part of hybrid-cp. The first part is #2054 (PR for the main branch: #2304).

Compared to part 1, this PR adds the following:

  • Added support for SFT datasets and sequence packing, along with a script; with these additions, hybrid-cp can run end-to-end. Convergence has been verified with Qwen3-30B on 32 GPUs, with max_seqlen set to 12288 and max_seqlen_per_dp_cp_rank set to 3072. In the figure below, 'bshd' refers to running with CP=4, where sequences are padded to max_seqlen and executed in the same bshd format as in pretraining; 'thd-packing' refers to using CP=4 while packing variable-length sequences; in 'hybrid-cp', the maximum CP group size is also 4.
  [Figure: loss curves comparing 'bshd', 'thd-packing', and 'hybrid-cp']
  • Added a mock SFT dataset that lets users control sequence lengths by specifying a sequence-length distribution or by providing a file containing sequence lengths (a minimal sketch follows this list).
  • Migrated the hybrid-cp and sequence-packing changes into a dataiterator_wrapper to minimize code changes. Adding a new scheduling algorithm now only requires adding a new scheduler class, which keeps the logic clear and easier to maintain.
  • Added support for FSDP with hybrid-cp; the loss curve is shown below (model: Qwen3-30B-A3B, hybrid-cp size: 4).
  • Added support for PP; FSDP+PP is not yet supported.
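
For illustration, a length-controlled mock SFT dataset could look roughly like the sketch below. This is a minimal sketch, not the PR's implementation: the class name MockSFTDataset, the (length, probability) distribution format, and the one-length-per-line file format are all assumptions made for this example.

```python
# Hypothetical sketch of a length-controlled mock SFT dataset (illustration only,
# not the PR's actual code). Lengths come either from a file with one length per
# line or from a discrete (length, probability) distribution.
import random
from typing import List, Optional, Tuple

import torch
from torch.utils.data import Dataset


class MockSFTDataset(Dataset):
    def __init__(
        self,
        num_samples: int,
        vocab_size: int,
        seqlen_distribution: Optional[List[Tuple[int, float]]] = None,
        seqlen_file: Optional[str] = None,
        seed: int = 1234,
    ):
        rng = random.Random(seed)
        if seqlen_file is not None:
            # One sequence length per line; reuse cyclically if the file is short.
            with open(seqlen_file) as f:
                lengths = [int(line) for line in f if line.strip()]
            self.seqlens = [lengths[i % len(lengths)] for i in range(num_samples)]
        else:
            # Sample lengths from the user-supplied (length, probability) pairs.
            lengths, probs = zip(*seqlen_distribution)
            self.seqlens = rng.choices(lengths, weights=probs, k=num_samples)
        self.vocab_size = vocab_size

    def __len__(self) -> int:
        return len(self.seqlens)

    def __getitem__(self, idx: int):
        seqlen = self.seqlens[idx]
        tokens = torch.randint(0, self.vocab_size, (seqlen,), dtype=torch.long)
        # Mimic SFT: zero out the loss on a random "prompt" prefix.
        prompt_len = torch.randint(1, max(2, seqlen // 2), (1,)).item()
        loss_mask = torch.ones(seqlen, dtype=torch.float)
        loss_mask[:prompt_len] = 0.0
        return {"tokens": tokens, "loss_mask": loss_mask, "seqlen": seqlen}
```

A mock dataset along these lines makes it easy to reproduce skewed sequence-length distributions when testing how hybrid-cp balances work across ranks.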

There are many improvements we want to make in future releases.

  1. The feature is limited to creating dynamic CP groups whose sizes are powers of 2 (see the sketch after this list). We hope to add complete dynamic support using changes in TransformerEngine DPA.
  2. The feature does not support CUDA graphs.
  3. The feature works best with FlashAttention rather than cuDNN FusedAttention, because the changing sequence lengths and CP sizes force cuDNN to recompile its graph, which erases the performance gains. We will advocate for dynamic-shape support in cuDNN FusedAttention.
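
To make the power-of-2 constraint concrete, a hybrid-cp scheduler could size each sequence's CP group roughly as in the sketch below; the helper name choose_cp_group_size and the exact rounding policy are assumptions for illustration, not the PR's actual scheduler.

```python
import math


def choose_cp_group_size(seqlen: int, max_seqlen_per_dp_cp_rank: int, max_cp_size: int) -> int:
    """Pick a power-of-2 CP group size so that each rank handles at most
    max_seqlen_per_dp_cp_rank tokens (illustrative sketch only)."""
    # Minimum number of ranks needed for this sequence.
    needed = math.ceil(seqlen / max_seqlen_per_dp_cp_rank)
    # Round up to the next power of 2, capped at the configured maximum CP size.
    cp_size = 1 << max(0, needed - 1).bit_length()
    return min(cp_size, max_cp_size)


# Example with the values from the convergence run above
# (max_seqlen_per_dp_cp_rank=3072, maximum CP group size=4):
assert choose_cp_group_size(2048, 3072, 4) == 1
assert choose_cp_group_size(5000, 3072, 4) == 2
assert choose_cp_group_size(12288, 3072, 4) == 4
```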

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch

The proposed review process for the `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@xiaoyao0115 xiaoyao0115 requested review from a team as code owners October 28, 2025 08:57
@xiaoyao0115 xiaoyao0115 added the enhancement New feature or request label Oct 28, 2025
@copy-pr-bot

copy-pr-bot bot commented Oct 28, 2025

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@xiaoyao0115 xiaoyao0115 force-pushed the hybrid-cp branch 3 times, most recently from f33edcd to 48e91d2 on November 2, 2025 09:33
@yanring yanring added module: moe dev branch Dev branch related issues and development labels Nov 5, 2025
@yanring
Contributor

yanring commented Nov 7, 2025

Is there any difference between this and #2054?

@kunlunl
Contributor

kunlunl commented Nov 7, 2025

Is there any difference between this and #2054?

This is the second MR; we need to merge #2054 first and then this one, #2000. (The second MR's number, #2000, is lower than the first's, #2054, because they were migrated from GitLab at different times.)

@yanring
Contributor

yanring commented Nov 10, 2025

Is there any difference between this and #2054?

This is the second MR; we need to merge #2054 first and then this one, #2000. (The second MR's number, #2000, is lower than the first's, #2054, because they were migrated from GitLab at different times.)

Got it, thanks! Could you please update the title to reflect this?

@xiaoyao0115 xiaoyao0115 changed the title [Dev] feat: hybrid-cp feature for dev branch (Author: Parth Kunlun Tailai) [Dev] feat: hybrid-cp feature for dev branch (part 2) Nov 11, 2025
@xiaoyao0115 xiaoyao0115 changed the title [Dev] feat: hybrid-cp feature for dev branch (part 2) [Dev] feat: hybrid-cp for dev branch (part 2) Nov 11, 2025
@kunlunl
Contributor

kunlunl commented Dec 1, 2025

/ok to test e0c90c5

@kunlunl
Contributor

kunlunl commented Jan 13, 2026

/ok to test a9def58

@kunlunl
Contributor

kunlunl commented Jan 13, 2026

/ok to test d12ccf1

@kunlunl
Contributor

kunlunl commented Jan 13, 2026

/ok to test 86581cd

# during pipeline parallelism, it should not be set if sequence length
# is constant during training.
args.variable_seq_lengths = False
if args.sequence_packing:
Contributor

Please move these validations into transformer_config.

max_seqlen = torch.empty(
1,
dtype=torch.int32,
device=torch.cuda.current_device(),
Contributor

Could you please clarify why these were removed?

Contributor

Supporting PP would make this more complex, so the thd-related logic was moved into a new, separate function, get_batch_on_this_rank_for_sequence_packing.

Contributor

The thd logic was added in the part-1 PR; this just reverts it to how it was before.

# Copyright (c) 2025 NVIDIA CORPORATION. All rights reserved.

from typing import Any, List, Optional
import enum
Contributor

Could we put this big change in a separate file?

Contributor

This data_schedule.py should be the separate file you want; it was included in the part-1 PR.

@ericharper
Contributor

@asolergi-nv FYI

@parthmannan parthmannan mentioned this pull request Jan 15, 2026