
Conversation

@VeeraRajasekhar (Contributor)

Description

  • Enable gfx950 (MI350) CI by addressing the specific failures we saw: FP8 GEMM coverage gaps in hipBLASLt, RMSNorm misalignment on odd strides (e.g., N=17389), fused optimizer tolerances, and unsupported quantized/activation-recompute test cases on ROCm.
  • Prevent the JAX GEMM/grouped-GEMM FFI from being marked cudaGraph-safe on ROCm to avoid graph-capture hangs; keep gfx950 FP8 layout support disabled until hipBLASLt coverage is validated.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Disable cudaGraph registration for the JAX gemm and grouped_gemm FFI on ROCm to stop graph-capture hangs on gfx950 (transformer_engine/jax/csrc/extensions/gemm.cpp).

  • Keep is_fp8_gemm_with_all_layouts_supported false on gfx950 until hipBLASLt FP8 layout coverage is validated (transformer_engine/jax/quantize/device_utils.py); a sketch of this gating follows the list.

  • Fix the RMSNorm Triton kernel for misaligned row strides by applying 16B alignment hints only when the pointers/strides are actually aligned; this resolves the test_norms dgamma mismatches and the test_transformer_layer_hidden_states_format numerics issues (see the alignment sketch after this list). Also relax fused-optimizer FP8 tolerances on MI350 (transformer_engine/pytorch/triton_kernels/rmsnorm.py, tests/pytorch/test_numerics.py, tests/pytorch/test_fused_optimizer.py).

  • Skip unsupported FP8 quantized-linear combinations on gfx950 where hipBLASLt lacks algorithms (tests/pytorch/test_fusible_ops.py); the skip pattern is shown in the sketch after this list.

  • Add a gfx950 detection helper and skip test_gpt_full_activation_recompute on MI350 configs that hipBLASLt cannot serve (transformer_engine/pytorch/utils.py, tests/pytorch/test_numerics.py).
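As a rough illustration of the gfx950 gating described above (a minimal sketch, not the code as merged; all helper names except is_fp8_gemm_with_all_layouts_supported are hypothetical), a PyTorch-side check might key off the GCN arch string that ROCm builds of PyTorch expose:

import functools

import torch


@functools.lru_cache(maxsize=None)
def _is_gfx950() -> bool:
    # Hypothetical helper: detect MI350-class (gfx950) devices. ROCm builds
    # of PyTorch report the arch as e.g. "gfx950:sramecc+:xnack-", so only
    # the base name before the first ":" is compared.
    if torch.version.hip is None or not torch.cuda.is_available():
        return False
    arch = torch.cuda.get_device_properties(0).gcnArchName
    return arch.split(":")[0] == "gfx950"


def is_fp8_gemm_with_all_layouts_supported() -> bool:
    # Report all-layout FP8 GEMM as unsupported on gfx950 until
    # hipBLASLt layout coverage is validated.
    return not _is_gfx950()


def maybe_skip_fp8_on_gfx950(quantized_compute: bool) -> None:
    # Test-side skip pattern for FP8 cases hipBLASLt has no algorithm for.
    import pytest
    if quantized_compute and _is_gfx950():
        pytest.skip("hipBLASLt lacks FP8 algorithms for this case on gfx950")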
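The RMSNorm fix hinges on a host-side alignment check before the Triton kernel emits 16-byte vectorization hints; a minimal sketch of that condition (helper name hypothetical):

import torch


def can_use_16b_alignment_hints(x: torch.Tensor, row_stride: int) -> bool:
    # 16B hints are only safe when both the base pointer and the row stride,
    # converted to bytes, are multiples of 16. For N=17389 rows of 2-byte
    # elements the stride is 34778 bytes (34778 % 16 == 10), so the kernel
    # must fall back to unhinted loads rather than assume alignment.
    stride_bytes = row_stride * x.element_size()
    return x.data_ptr() % 16 == 0 and stride_bytes % 16 == 0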

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes


#ifdef __HIP_PLATFORM_AMD__

// Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable.
Collaborator:

What does unstable mean?

Contributor Author:

6192 - OperatorTest/GEMMTestSuite.Testfp8xfp8xbf16xbf16xbf16/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
6768 - OperatorTest/GEMMTestSuite.Testfp8xbf8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7344 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp32/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)
7488 - OperatorTest/GEMMTestSuite.Testbf8xfp8xbf16xbf16xfp16/2304x768x4096x0x0xTNxM  # GetParam() = ((2304, 768, 4096), false, false, (true, false), 1) (Failed)

These test cases fail at random, so we decided to skip them for this MI350 bring-up. When I tested on ROCm 7.2 there was no issue.

Collaborator:

Guard it with #if HIP_VERSION < 70200000 then; the comments about the temporary disable/re-enable and the mention of ROCm 7.2 can then be removed.


// Re-enable after ROCm 7.2 once hipBLASLt fixes land.
if (prop.major == 9 && prop.minor == 5 &&
    params.transa && !params.transb &&
    params.m == 2304 && params.k == 768 && params.n == 4096) {
Collaborator:

There is only one size for DqTest. Instead of skipping the test, just use a different size in test_case_sizes_mxfp8, for example 768, 3072, 4096.
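A sketch of that suggestion (test_case_sizes_mxfp8 is the name from the thread; the surrounding test file layout is assumed):

# Use a size hipBLASLt handles on gfx950 instead of skipping DqTest:
test_case_sizes_mxfp8 = [
    (768, 3072, 4096),  # replaces the gfx950-unstable (2304, 768, 4096) case
]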

Contributor Author:

Updated.

@VeeraRajasekhar (Contributor Author)

Rebased onto dev.

@VeeraRajasekhar (Contributor Author)

Test report for MI355 with Level=3:

  • No issues reported in the single-GPU tests.
  • PyTorch multi-GPU tests had no issues.
  • The JAX test [auto] test_distributed_fused_attn.py hit its timeout due to a known hang; the other JAX tests passed.

@VeeraRajasekhar VeeraRajasekhar merged commit f141f34 into dev Jan 9, 2026
2 of 4 checks passed