KshitijLakhani commented Dec 24, 2025

Description

Reduce the number of test cases for

  1. L0 JAX fused attn
  2. L1 JAX dist fused attn
  3. L2 JAX dist fused attn

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

1A] The current L1 and L2 tests run every dp, cp, tp combination where dp*cp*tp <= num gpus.

Below is an example of the L1 dist timing for dist fused attn tests using TE 2.11 (B200x8)

2157 collected
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  27x |    2.84s | avg:   0.11s
test_context_parallel_allgather_attn                         | 320x |  917.64s | avg:   2.87s
test_context_parallel_allgather_attn_shardy                  |  80x |  503.19s | avg:   6.29s
test_context_parallel_allgather_striped_attn                 | 320x |  338.28s | avg:   1.06s
test_context_parallel_ring_attn                              | 1280x | 1662.10s | avg:   1.30s
test_context_parallel_ring_attn_shardy                       |  40x |   46.24s | avg:   1.16s
test_cross_attn                                              |  18x |   31.73s | avg:   1.76s
test_self_attn                                               |  54x |  125.43s | avg:   2.32s
test_self_attn_shardy                                        |  18x |   16.87s | avg:   0.94s
================================================================================
TOTAL RUNTIME                                                |      | 3644.31s |
================================================================================

1B] This PR runs only those L1 and L2 combinations where dp*cp*tp==num gpus.
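The difference between the two selection rules can be sketched as follows. This is a minimal illustration, not the parametrization code in this PR: the helper name `mesh_configs` and the candidate sizes `(1, 2, 4, 8)` are assumptions, and the real tests may exclude some of these triples (for example, those without context parallelism).

```python
from itertools import product


def mesh_configs(num_gpus, exact, sizes=(1, 2, 4, 8)):
    """Return (dp, cp, tp) triples drawn from `sizes` that fit on `num_gpus`.

    Hypothetical helper for illustration; not the test code touched by the PR.
    """
    keep = []
    for dp, cp, tp in product(sizes, repeat=3):
        used = dp * cp * tp
        # exact=False mimics the old rule, exact=True the rule in this PR.
        fits = used == num_gpus if exact else used <= num_gpus
        if fits:
            keep.append((dp, cp, tp))
    return keep


before = mesh_configs(8, exact=False)  # old rule: dp*cp*tp <= num gpus
after = mesh_configs(8, exact=True)    # new rule: dp*cp*tp == num gpus
print(len(before), len(after))
```

Every triple kept by the `==` rule is also kept by the `<=` rule, so the new selection is a strict subset of the old one.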

Below is an example of the L1 dist timing for dist fused attn tests using this PR (B200x8)

1137 collected
================================================================================
TEST RUNTIME SUMMARY (grouped by function)
================================================================================
test                                                         |  27x |    2.37s | avg:   0.09s
test_context_parallel_allgather_attn                         | 160x |  344.81s | avg:   2.16s
test_context_parallel_allgather_attn_shardy                  |  40x |  167.59s | avg:   4.19s
test_context_parallel_allgather_striped_attn                 | 160x |  190.06s | avg:   1.19s
test_context_parallel_ring_attn                              | 640x |  729.42s | avg:   1.14s
test_context_parallel_ring_attn_shardy                       |  20x |   16.94s | avg:   0.85s
test_cross_attn                                              |  18x |   23.38s | avg:   1.30s
test_self_attn                                               |  54x |  118.77s | avg:   2.20s
test_self_attn_shardy                                        |  18x |   12.75s | avg:   0.71s
================================================================================
TOTAL RUNTIME                                                |      | 1606.09s |
================================================================================

This PR collects 1020 fewer tests (2157 - 1137).
For the test_context_parallel tests, the number of collected tests is halved, because only the cases with dp*cp*tp==8 are collected and those with dp*cp*tp==4 and dp*cp*tp==2 are not. This has limited impact on CI coverage, since CI also runs on H100x4 and GB200x4, where the dp*cp*tp==4 cases are covered.
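As a quick cross-check of these numbers, the per-function counts from the two L1 summaries can be compared directly. The counts below are copied from the tables above; the variable names are only for illustration.

```python
# Per-function collected counts copied from the two L1 timing summaries.
before = {
    "test": 27,
    "test_context_parallel_allgather_attn": 320,
    "test_context_parallel_allgather_attn_shardy": 80,
    "test_context_parallel_allgather_striped_attn": 320,
    "test_context_parallel_ring_attn": 1280,
    "test_context_parallel_ring_attn_shardy": 40,
    "test_cross_attn": 18,
    "test_self_attn": 54,
    "test_self_attn_shardy": 18,
}
after = {
    "test": 27,
    "test_context_parallel_allgather_attn": 160,
    "test_context_parallel_allgather_attn_shardy": 40,
    "test_context_parallel_allgather_striped_attn": 160,
    "test_context_parallel_ring_attn": 640,
    "test_context_parallel_ring_attn_shardy": 20,
    "test_cross_attn": 18,
    "test_self_attn": 54,
    "test_self_attn_shardy": 18,
}

# Tests removed per function, and the share coming from context-parallel tests.
removed = {name: before[name] - after[name] for name in before}
cp_removed = sum(
    n for name, n in removed.items() if name.startswith("test_context_parallel")
)
print(sum(before.values()), sum(after.values()), cp_removed)  # 2157 1137 1020
```

All 1020 removed tests come from the test_context_parallel functions; the counts of the other functions are unchanged.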

1C] Changes in testing due to this PR:

  • dp*cp*tp==2 tests will no longer be covered. If this coverage is needed in the future, CI could set CUDA_VISIBLE_DEVICES=0,1 for any of these configs and run these tests as well.
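One way a CI step might restrict a run to two GPUs is to launch the tests with a modified environment. This is only a sketch: the helper name `env_with_visible_gpus` is hypothetical, and how CI actually invokes the tests is not described in this PR.

```python
import os


def env_with_visible_gpus(gpu_ids, base_env=None):
    """Copy an environment and restrict CUDA to the given device indices."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in gpu_ids)
    return env


env = env_with_visible_gpus([0, 1])
# A CI step could then launch the tests with this environment, for example
# subprocess.run(["pytest", TEST_PATH], env=env), where TEST_PATH is a
# placeholder for the dist fused attn test file (not named in this PR).
```

The variable has to be set before the process initializes CUDA, which is why the sketch passes it to a child process instead of changing it after startup.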

With current CI configs,

  • on B200, the current tests run all dp*cp*tp<=8 cases; with this PR, only the dp*cp*tp==8 cases run
  • on H100, the current tests run all dp*cp*tp<=4 cases; with this PR, only the dp*cp*tp==4 cases run
  • on GB200, the current tests run all dp*cp*tp<=4 cases; with this PR, only the dp*cp*tp==4 cases run

Across all CI configs, the set of covered test cases stays the same apart from the dp*cp*tp==2 cases; a given CI config (GPU arch) simply no longer runs every combination itself.

2] Consolidation of the fused attn tests, which in turn reduces the number of L0 tests run in CI

  • The 15120 tests collected for all possible data shape and qkv layout combinations in the current fused attn tests are reduced to 5040 by using only specific combinations of data shape and qkv layout

Below is an example of the L0 timing for fused attn tests using TE 2.11 (H100)

test_backward                                                | 15204x | 1569.98s | avg:   0.10s

Below is an example of the L0 timing for fused attn tests using this PR (H100)

test_backward                                                | 5124x |  860.32s | avg:   0.17s
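The two test_backward rows can be compared with a few lines of arithmetic. The figures are copied from the rows above; the variable names are illustrative only.

```python
# (collected tests, total seconds) copied from the two L0 test_backward rows.
before_count, before_secs = 15204, 1569.98
after_count, after_secs = 5124, 860.32

count_ratio = before_count / after_count  # how many times fewer tests
time_ratio = before_secs / after_secs     # how many times less runtime
saved_pct = 100 * (before_secs - after_secs) / before_secs

print(f"{count_ratio:.2f}x fewer tests, {time_ratio:.2f}x less time, "
      f"{saved_pct:.1f}% of runtime saved")
```

The test count drops by roughly 3x while the runtime drops by roughly 45%, which is consistent with the higher per-test average (0.17s versus 0.10s) reported for this PR.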

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

…only those cp,dp,tp combinations are picked where cp*dp*tp is equal to num gpus

Signed-off-by: Kshitij Lakhani <[email protected]>
KshitijLakhani changed the title from "Refactor and trim TE JAX Attn testing" to "[JAX] Refactor and trim TE JAX Attn testing" Dec 24, 2025
KshitijLakhani self-assigned this Dec 24, 2025
KshitijLakhani force-pushed the klakhani/maint/reduce-test-time branch from d2f9634 to d6a2951 January 8, 2026 01:30
KshitijLakhani force-pushed the klakhani/maint/reduce-test-time branch from 1b12ca4 to 4dbd600 January 8, 2026 21:50
KshitijLakhani marked this pull request as ready for review January 8, 2026 21:53
@KshitijLakhani
Collaborator Author

Pipeline: 41376441 passes
Confirmed that the number of collected tests and the reduction in runtime are as expected.

Collaborator

jberchtold-nvidia left a comment

LGTM, thanks! Awesome improvement in CI runtime!

KshitijLakhani merged commit 5f0e3b9 into NVIDIA:main Jan 9, 2026
12 checks passed
KshitijLakhani deleted the klakhani/maint/reduce-test-time branch January 9, 2026 00:32