[JAX] Refactor and trim TE JAX Attn testing #2542
Merged
+162
−24
Description
Reduce the number of test cases for
Type of change
Changes
1A] Current tests run all possible combinations for L1 and L2 where dp*cp*tp <= num gpus.
Below is an example of the L1 dist timing for dist fused attn tests using TE 2.11 (B200x8)
1B] This PR runs only those L1 and L2 combinations where dp*cp*tp == num gpus.
Below is an example of the L1 dist timing for dist fused attn tests using this PR (B200x8)
There is a reduction of 1020 (2157-1137) tests collected owing to the change in this PR.
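The difference between the collection rules in 1A and 1B can be illustrated with a minimal sketch. The function name `mesh_configs` and the candidate axis sizes `(1, 2, 4, 8)` are assumptions made for illustration; they are not the helper or the size lists used by the TE JAX test suite, so the counts here do not correspond to the collected-test numbers above.

```python
from itertools import product


def mesh_configs(num_gpus, sizes=(1, 2, 4, 8), exact=True):
    """Enumerate (dp, cp, tp) mesh shapes for a machine with num_gpus devices.

    exact=False mimics the old rule (dp*cp*tp <= num_gpus),
    exact=True mimics the rule in this PR (dp*cp*tp == num_gpus).
    The candidate sizes are hypothetical, not taken from the real tests.
    """
    configs = []
    for dp, cp, tp in product(sizes, repeat=3):
        ndev = dp * cp * tp
        keep = (ndev == num_gpus) if exact else (ndev <= num_gpus)
        if keep:
            configs.append((dp, cp, tp))
    return configs


old_rule = mesh_configs(8, exact=False)  # every shape that fits on 8 GPUs
new_rule = mesh_configs(8, exact=True)   # only shapes that use all 8 GPUs

# Shapes dropped by the stricter rule use fewer than 8 devices.
dropped = [cfg for cfg in old_rule if cfg not in new_rule]
```

Under these assumed sizes, every shape kept by the new rule is also produced by the old rule; the dropped shapes are exactly those that leave some GPUs idle.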
For the test_context_parallel tests, the number of tests is halved in this PR, because only the cases for dp*cp*tp == 8 are collected and not those for dp*cp*tp == 4 and dp*cp*tp == 2. This is not a big problem in CI, since we also run on H100x4 and GB200x4, so the dp*cp*tp == 4 cases are covered there.
1C] Changes in testing due to this PR:
- dp*cp*tp == 2 tests will not be covered. If coverage for this is needed in the future, CI could set CUDA_VISIBLE_DEVICES=0,1 for any of these configs and run these tests as well.
- With current CI configs:
  - B200 runs dp*cp*tp <= 8; with this PR, only the dp*cp*tp == 8 cases will run.
  - H100 runs dp*cp*tp <= 4; with this PR, only the dp*cp*tp == 4 cases will run.
  - GB200 runs dp*cp*tp <= 4; with this PR, only the dp*cp*tp == 4 cases will run.
- Overall, the set of test cases stays the same, but not every combination will run on every CI config (GPU arch).
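The CUDA_VISIBLE_DEVICES=0,1 idea from 1C could be sketched as a small launcher that restricts the visible devices before starting pytest. This is only an illustration: `restricted_env` and `run_tests_on_devices` are hypothetical helper names, and the test path is left as a parameter because the actual CI invocation is not described here.

```python
import os
import subprocess
import sys


def restricted_env(device_ids, base_env=None):
    """Return a copy of the environment that exposes only device_ids to CUDA."""
    env = dict(os.environ if base_env is None else base_env)
    env["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in device_ids)
    return env


def run_tests_on_devices(test_path, device_ids):
    """Run pytest on test_path in a subprocess that sees only device_ids.

    test_path is supplied by the caller; no particular file is assumed here.
    """
    cmd = [sys.executable, "-m", "pytest", test_path]
    result = subprocess.run(cmd, env=restricted_env(device_ids), check=False)
    return result.returncode
```

Calling `run_tests_on_devices(path, [0, 1])` would start the tests in a process where only two GPUs are visible, so a dp*cp*tp == num gpus rule would then select the dp*cp*tp == 2 cases. The variable has to be set before the test process initializes CUDA, which is why it is passed through the subprocess environment rather than changed after startup.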
2] Consolidation of fused attn tests, which in turn reduces the number of L0 tests run in CI
Below is an example of the L0 timing for fused attn tests using TE 2.11 (H100)
Below is an example of the L0 timing for fused attn tests using this PR (H100)
Checklist: