
Enable CUDA graph for ADAM optimizer#3429

Merged
Phlip79 merged 9 commits into NVIDIA:main from vasunvidia:vrengasamy/optimizer_cuda_graph_main
Apr 3, 2026

Conversation

@vasunvidia
Contributor

Add wait stream before copying next batch to CG input

Add OptimizerCudaGraphWrapper to CUDA graph optimizer

Cleanup

What does this PR do ?

⚠️ For major changes (either in lines of code or in their impact), please make sure to first share a design doc with the team. If you're unsure of the best way to do so, contact @mcore-oncall.

Contribution process

flowchart LR
    A[Pre-checks] --> B[PR Tests]
    subgraph Code Review/Approval
        C1[Expert Review] --> C2[Final Review]
    end
    B --> C1
    C2 --> D[Merge]

Pre-checks

  • I want this PR in a versioned release and have added the appropriate Milestone (e.g., Core 0.8)
  • I have added relevant unit tests
  • I have added relevant functional tests
  • I have added proper typing to my code (see the Typing guidelines)
  • I have added relevant documentation
  • I have run the autoformatter.sh on my PR

Code review

The following process is enforced via the CODEOWNERS file for changes into megatron/core. For changes outside of megatron/core, it is up to the PR author whether or not to tag the Final Reviewer team.

For MRs into `main` branch

Feel free to message or comment the @mcore-oncall to help accelerate your merge into main. The less complex your PR is, the faster it will be approved and merged!

(Step 1): Add PR label Expert Review

(Step 2): Collect the expert reviewers reviews

  1. Attach the Expert Review label when your PR is ready for review.
  2. GitHub auto-assigns expert reviewers based on your changes. They will get notified and pick up your PR soon.

⚠️ Only proceed to the next step once all reviewers have approved, merge conflicts are resolved, and the CI is passing.
Final Review might get declined if these requirements are not fulfilled.

(Step 3): Final Review

  1. Add Final Review label
  2. GitHub auto-assigns final reviewers based on your changes. They will get notified and pick up your PR soon.

(Optional Step 4): Cherry-pick into release branch

If this PR also needs to be merged into core_r* release branches, after this PR has been merged, select Cherry-pick to open a new PR into the release branch.

For MRs into `dev` branch The proposed review process for `dev` branch is under active discussion.

MRs are mergeable after one approval by either eharper@nvidia.com or zijiey@nvidia.com.

Merging your PR

Any member of core-adlr and core-nemo will be able to merge your PR.

@vasunvidia vasunvidia requested review from a team as code owners February 14, 2026 05:55
@copy-pr-bot

copy-pr-bot bot commented Feb 14, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@ko3n1g ko3n1g requested a review from a team February 14, 2026 05:56
@Phlip79 Phlip79 added the Final Review PR is in the "final review" stage label Mar 4, 2026
@vasunvidia vasunvidia force-pushed the vrengasamy/optimizer_cuda_graph_main branch from 86eb0f6 to d64787e Compare March 19, 2026 17:19
@vasunvidia vasunvidia requested review from a team as code owners March 19, 2026 17:19
@svcnvidia-nemo-ci svcnvidia-nemo-ci removed the Final Review PR is in the "final review" stage label Mar 19, 2026
@gautham-kollu gautham-kollu requested review from deepakn94, gautham-kollu and jiemingz and removed request for deepakn94 March 23, 2026 16:47
@erhoo82 erhoo82 added the complexity: low, 26.04 (this PR is high priority and should be merged asap), and Expert Review [deprecated] labels Mar 23, 2026
@asolergi-nv
Contributor

/claude review

torch.cuda.synchronize()
torch.distributed.barrier()
logger.info(f'Optimizer CUDA graph capture done!!!')
if OptimizerCudaGraphWrapper.cuda_graph is None:
Contributor

Bug: On the capture iteration (curr_iteration == cuda_graph_warmup_steps), the optimizer step runs once during graph capture (line 45), then falls through to the else branch (line 52) which calls replay() — executing the optimizer step a second time. This will silently corrupt training on that iteration.

This if should be elif:

Suggested change
if OptimizerCudaGraphWrapper.cuda_graph is None:
elif OptimizerCudaGraphWrapper.cuda_graph is None:
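The scenario the reviewer describes can be reproduced with a small CPU-only sketch. Everything here is illustrative (`FakeGraph`, `run_iteration`, the state layout), not the actual Megatron-LM code: with a bare `if`, the capture iteration both records the step and immediately replays it.

```python
class FakeGraph:
    """Stands in for a captured CUDA graph; replay() re-runs the recorded work."""

    def __init__(self, step_fn):
        self._step_fn = step_fn

    def replay(self):
        self._step_fn()


def run_iteration(state, curr_iteration, warmup_steps, fixed):
    """Dispatch one optimizer step.

    fixed=False models the bare `if` the reviewer flagged; fixed=True models
    turning it into `elif`, i.e. skipping the dispatch when we just captured.
    """

    def do_step():
        state["step_count"] += 1

    captured_now = False
    if curr_iteration == warmup_steps:
        do_step()  # the step executes once while the graph is being captured
        state["graph"] = FakeGraph(do_step)
        captured_now = True
    if not (fixed and captured_now):  # bare `if` falls through after capture
        if state["graph"] is None:
            do_step()  # eager step during warmup iterations
        else:
            state["graph"].replay()  # replays the captured step


def steps_per_iteration(fixed, iterations=5, warmup_steps=2):
    """Count optimizer step executions on each iteration."""
    state = {"graph": None, "step_count": 0}
    deltas = []
    for it in range(iterations):
        before = state["step_count"]
        run_iteration(state, it, warmup_steps, fixed)
        deltas.append(state["step_count"] - before)
    return deltas
```

With `warmup_steps=2`, `steps_per_iteration(fixed=False)` returns `[1, 1, 2, 1, 1]`: the capture iteration steps twice (once during capture, once via replay), which is the silent corruption the review flags. The `elif` variant returns `[1, 1, 1, 1, 1]`.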


import torch

from megatron.core.tensor_parallel.random import get_all_rng_states
Contributor

Unused import — get_all_rng_states is never referenced in this file.

@asolergi-nv
Contributor

/ok to test 5284819

@svcnvidia-nemo-ci svcnvidia-nemo-ci added this to the Core 0.16 milestone Mar 24, 2026
@@ -0,0 +1,70 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

"""Full iteration CUDA graph for training."""
Contributor

This description doesn't make much sense given the name of the file.

Contributor Author

Fixed it.

config_logger_dir: str = ""
"""When non-empty, dumps entry-point configs to config_logger_dir"""

on_device_clip_grad: bool = False
Contributor

Is there a downside to doing this? Should we just default to on, or even just do it with no knob?

Contributor Author

Added a commit to remove on_device_clip_grad knob and use on_device_clip_grad if the kernel is present.
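The no-knob approach described here typically amounts to feature detection: enable the on-device path only when the optional kernel is importable. A minimal sketch of that pattern; `_hypothetical_fused_kernels` and `clip_grad_norm_on_device` are made-up names, not the actual Megatron-LM symbols.

```python
import importlib


def has_on_device_clip_kernel(module_name="_hypothetical_fused_kernels"):
    """Feature-detect the optional fused kernel instead of exposing a config knob."""
    try:
        mod = importlib.import_module(module_name)
    except ImportError:
        return False  # extension not installed: fall back to the host-side path
    return hasattr(mod, "clip_grad_norm_on_device")


# The fast path is chosen automatically whenever the kernel is present.
on_device_clip_grad = has_on_device_clip_kernel()
```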

@Phlip79 Phlip79 removed the Expert Review [deprecated] Apply this label to indicate that your PR is ready for expert review. label Mar 25, 2026
Contributor

@maanug-nv maanug-nv left a comment


LGTM. Left one nit. Approving to unblock, but please add the comment before merging.

cuda_graph_helper.delete_cuda_graphs()

if args.optimizer_cuda_graph:
del optimizer.step
Contributor

This seems a bit unintuitive. Can we add a comment explaining that it resets `step` back from the wrapper to the normal method implementation?

Contributor Author

Added comment.
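For readers wondering why `del optimizer.step` restores the original behavior: assigning the graphed wrapper creates an instance attribute that shadows the class's method, and `del` removes only that instance attribute, so attribute lookup falls back to the class. A generic sketch (the class and return values are illustrative, not the Megatron-LM optimizer):

```python
class TinyOptimizer:
    def step(self):
        return "normal step"


opt = TinyOptimizer()

# Wrap: the instance attribute shadows the class method.
opt.step = lambda: "graphed step"
assert opt.step() == "graphed step"

# Unwrap: deleting the instance attribute exposes the class method again.
del opt.step
assert opt.step() == "normal step"
```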

@ko3n1g
Contributor

ko3n1g commented Mar 30, 2026

/ok to test 131c546

@Phlip79
Member

Phlip79 commented Mar 31, 2026

/ok to test b7066fc

@Phlip79 Phlip79 requested review from a team March 31, 2026 19:02
@@ -0,0 +1,61 @@
# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
Contributor

Wrong year.

Contributor Author

Fixed.

@jiemingz jiemingz requested a review from a team April 1, 2026 12:02
@vasunvidia vasunvidia requested a review from jaredcasper April 1, 2026 21:23
@Phlip79
Member

Phlip79 commented Apr 2, 2026

/ok to test 83d686c

@svcnvidia-nemo-ci svcnvidia-nemo-ci added the Final Review PR is in the "final review" stage label Apr 2, 2026
@ko3n1g ko3n1g added the core_r0.17.0 Auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge. label Apr 2, 2026
@Phlip79
Member

Phlip79 commented Apr 2, 2026

/ok to test 180dd44

@svcnvidia-nemo-ci svcnvidia-nemo-ci added Approved All necessary approvals have been made and removed Final Review PR is in the "final review" stage labels Apr 3, 2026
@Phlip79
Member

Phlip79 commented Apr 3, 2026

/ok to test cc1e076

Use param_group.get('lr') instead of param_group['lr'] to avoid
KeyError on the first step() call from __init__, where param groups
may not yet have an 'lr' key.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
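The dict-access difference this commit relies on, as a minimal sketch (the param-group contents here are illustrative):

```python
# A freshly constructed param group may not carry an 'lr' key yet.
param_group = {"params": [], "weight_decay": 0.01}

# Subscripting raises KeyError for a missing key; .get() returns a default.
try:
    param_group["lr"]
    raised = False
except KeyError:
    raised = True

assert raised
assert param_group.get("lr") is None        # safe on the first step() call
assert param_group.get("lr", 1e-4) == 1e-4  # or supply an explicit default
```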
@Phlip79
Member

Phlip79 commented Apr 3, 2026

/ok to test f0dd8ab

@Phlip79 Phlip79 added this pull request to the merge queue Apr 3, 2026
@svcnvidia-nemo-ci

🔄 Merge queue validation started!

You can track the progress here: https://github.com/NVIDIA/Megatron-LM/actions/runs/23965168729

Merged via the queue into NVIDIA:main with commit 3d87bfc Apr 3, 2026
61 of 63 checks passed
ko3n1g pushed a commit that referenced this pull request Apr 3, 2026
Co-authored-by: Antoni-Joan Solergibert <asolergibert@nvidia.com>
Co-authored-by: Philip Petrakian <ppetrakian@nvidia.com>
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Signed-off-by: NeMo Bot <nemo-bot@nvidia.com>

Labels

  • 26.04: this PR is high priority and should be merged asap
  • Approved: all necessary approvals have been made
  • complexity: low
  • core_r0.17.0: auto-cherrypick to release branch. Apply before merge; cherrypick happens after merge.


9 participants