
Conversation

@Micky774
Contributor

@Micky774 Micky774 commented Nov 5, 2025

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@ipanfilo
Collaborator

ipanfilo commented Nov 6, 2025

Which functionality not covered by existing tests does it cover?

@wangye805
Collaborator

> Which functionality not covered by existing tests does it cover?

Previously our JAX and PyTorch distributed fused-attn only enabled the v2 CK backends, not v3.
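
For concreteness, here is a rough sketch (not the exact CI invocation) of what the new coverage runs: the v3 CK kernels are opted into through the NVTE_CK_USES_FWD_V3 / NVTE_CK_USES_BWD_V3 environment variables that the ci/jax.sh change below sets, e.g.:

```python
# Rough sketch only; the env-var names come from the ci/jax.sh change in this PR.
import os
import subprocess

env = dict(
    os.environ,
    # Hang workaround already used for the distributed JAX tests in ci/jax.sh
    XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false",
    # Opt the CK fused-attn forward/backward passes into the v3 kernels
    NVTE_CK_USES_FWD_V3="1",
    NVTE_CK_USES_BWD_V3="1",
)
subprocess.run(
    ["pytest", "-v", "test_distributed_fused_attn.py",
     "-k", "not test_context_parallel_ring_attn"],
    env=env,
    check=True,
)
```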

@ipanfilo
Collaborator

ipanfilo commented Nov 6, 2025

> > Which functionality not covered by existing tests does it cover?
>
> Previously our JAX and PyTorch distributed fused-attn only enabled the v2 CK backends, not v3.

Yes, but does it run different fused-attn backend configs/kernels than the non-distributed ones? Or is there a functional concern about their coexistence with RCCL?

@wangye805
Collaborator

> > > Which functionality not covered by existing tests does it cover?
> >
> > Previously our JAX and PyTorch distributed fused-attn only enabled the v2 CK backends, not v3.
>
> Yes, but does it run different fused-attn backend configs/kernels than the non-distributed ones? Or is there a functional concern about their coexistence with RCCL?

In the distributed fused-attn (CP) pytest suite, the reference run is usually a single-GPU fused-attn over the full seqlen (for example sq=skv=8192) using the default attn backend. The target run decomposes that single full-size fused-attn into 4 or 8 smaller fused-attn calls (for example sq=skv=4096), runs those smaller fused-attn instances with the default backend, and then "glues" the results back together in the CP fashion.

In my opinion, there are two reasons we need to enable v3 for distributed fused-attn:
1) The decomposition may create new fused-attn configs not covered by our single-GPU fused-attn pytests.
2) The "glue" step actually exercises the softmax_lse generated in the forward pass: if the softmax_lse results are incorrect, the glued results will be wrong, and our single-GPU fused-attn pytests do not test softmax_lse at all (see the sketch below).
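
To make the "glue" step concrete, here is a minimal NumPy sketch (illustrative only, not the actual TE kernels or the test code) of how per-chunk attention outputs are combined using their softmax LSE; an incorrect softmax_lse would break the final check:

```python
# Illustrative only: naive single-head attention, full-sequence reference vs.
# chunked KV with LSE-based combination (the "glue" step used in CP tests).
import numpy as np

def attn_chunk(q, k, v):
    """Attention over one KV chunk; returns the output and per-row softmax LSE."""
    s = q @ k.T / np.sqrt(q.shape[-1])            # [sq, skv] scores
    lse = np.log(np.exp(s).sum(axis=-1))          # [sq] log-sum-exp per query row
    return np.exp(s - lse[:, None]) @ v, lse      # [sq, d] output, [sq] LSE

rng = np.random.default_rng(0)
sq, skv, d = 8, 16, 4
q = rng.normal(size=(sq, d))
k = rng.normal(size=(skv, d))
v = rng.normal(size=(skv, d))

# Reference: one attention call over the full KV sequence.
ref, _ = attn_chunk(q, k, v)

# "Distributed" path: attend to each KV chunk separately, then glue with the LSE.
outs, lses = zip(*(attn_chunk(q, k[a:b], v[a:b]) for a, b in [(0, 8), (8, 16)]))
lses = np.stack(lses)                              # [n_chunks, sq]
glued_lse = np.log(np.exp(lses).sum(axis=0))       # combined LSE per query row
weights = np.exp(lses - glued_lse)                 # per-chunk rescaling factors
glued = sum(w[:, None] * o for w, o in zip(weights, outs))

assert np.allclose(glued, ref)  # an incorrect softmax_lse would break this check
```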

ci/jax.sh Outdated
*0.4.35*)
# Workaround for distributed tests hang with xla_flag
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" NVTE_CK_USES_FWD_V3=1 NVTE_CK_USES_BWD_V3=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'
Collaborator


This will run it with AOTriton too

Contributor Author


Updated with a guard in the JAX CI script.

Collaborator


With those changes the env variables are not seen by the run method; they are applied to the test call only. Use run_default_fa_lbl instead. All V3 calls should be labelled with "v3" to distinguish them from the regular test_distributed_fused_attn calls.

@wenchenvincent
Collaborator

@Micky774 Could you rebase onto the latest dev to incorporate the hotfix for the core sgpu tests?

@wenchenvincent
Collaborator

@ipanfilo Could you check if all your comments have been addressed?


