From 96ef617a6650f1cab62c7e339cd8bb6d32e2985f Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 9 Dec 2025 06:23:04 +0000 Subject: [PATCH 01/15] [CI] Skipped test_gpt_full_activation_recompute tests for gfx950 --- tests/pytorch/test_numerics.py | 9 ++++++++- transformer_engine/pytorch/utils.py | 6 ++++++ 2 files changed, 14 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 1787ab191..16c603422 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -28,7 +28,7 @@ is_bf16_compatible, ) if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi200, is_mi308 + from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi350 from transformer_engine.pytorch import ( DotProductAttention, @@ -757,6 +757,13 @@ def test_gpt_full_activation_recompute( pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) + if IS_HIP_EXTENSION and is_mi350(): + if (dtype == torch.bfloat16 + and not fp8 + and not use_reentrant + and recipe.float8_per_tensor_scaling() + ): + pytest.skip("hipBLASLt does not provide suitable algorithms on MI350 for this config.") config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 9d0d71fdc..f49b98cb4 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -456,6 +456,12 @@ def is_mi308(): import re return (re.search('AMD Instinct MI308', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) + @functools.lru_cache(maxsize=None) + def is_mi350(): + """check whether this machine is mi35x""" + props = torch.cuda.get_device_properties(torch.cuda.current_device()) + return (props.major, props.minor) == (9, 5) + @functools.lru_cache(maxsize=None) def is_fp8_fnuz(): return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) From 3da9fb3e3e72e75d00c9a55b915e94bc998ad564 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 9 Dec 2025 06:26:43 +0000 Subject: [PATCH 02/15] [CI] Skipped unsupported test_basic_linear_quantized tests on gfx950 --- tests/pytorch/test_fusible_ops.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 78894d97d..d8044e53b 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -36,6 +36,9 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine_torch as tex +from torch.utils.cpp_extension import IS_HIP_EXTENSION +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import is_mi350 # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -918,6 +921,16 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" + if IS_HIP_EXTENSION and is_mi350(): + if ( + quantization + and quantization.startswith("fp8") + and quantized_compute + and (quantized_grad_input or quantized_output) + ): + pytest.skip( + "hipBLASLt does not provide suitable algorithms on gfx950 for this config." 
+ ) if quantization is None: pytest.skip("Skipping case without quantization") self._test_basic_linear( From 2dcf0c59935d7df01d0a7556fa6d274a779477da Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Wed, 10 Dec 2025 08:34:33 +0000 Subject: [PATCH 03/15] [CI] Fixed test_numerics, test_norms, test_fused_optimizer failures for gfx950 ci enablement --- tests/pytorch/test_fused_optimizer.py | 7 ++ tests/pytorch/test_numerics.py | 22 ++++- .../pytorch/triton_kernels/rmsnorm.py | 92 +++++++++++++++---- 3 files changed, 101 insertions(+), 20 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index e04f0477b..47e8820ed 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -9,6 +9,7 @@ import pytest import torch from torch import nn +from torch.utils.cpp_extension import IS_HIP_EXTENSION from torch.testing._internal.common_device_type import largeTensorTest import transformer_engine.pytorch as te from transformer_engine.common.recipe import DelayedScaling @@ -18,6 +19,9 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import gpu_autocast_ctx +if IS_HIP_EXTENSION: + from transformer_engine.pytorch.utils import is_mi350 + # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -378,6 +382,7 @@ def test_bf16_exp_avg(self): @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): + model_tol = 3e-2 if IS_HIP_EXTENSION and is_mi350() else None self.gen_precision_aware_test( use_fp8_params=False, param_dtype=torch.bfloat16, @@ -388,6 +393,8 @@ def test_fp8_exp_avg(self): exp_avg_sq_dtype=torch.float32, master_rtol=1e-2, master_atol=1e-2, + model_rtol=model_tol, + model_atol=model_tol, ) @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 16c603422..7662858d1 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2813,10 +2813,24 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): max_seqlen_kv=config.seq_len, ) - torch.testing.assert_close( - y_bshd, - y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), - ) + if IS_HIP_EXTENSION: + tols_thd = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + # ROCm fused attention (CK) on THD can produce slightly larger error + tols_thd["atol"] = 2e-3 + _, use_aotriton, use_ck = rocm_attn_backend() + if use_aotriton and not use_ck: + tols_thd["rtol"] = tols_thd["rtol"] * 3 + torch.testing.assert_close( + y_bshd, + y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + **tols_thd, + ) + else: + torch.testing.assert_close( + y_bshd, + y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), + ) @pytest.mark.parametrize("dtype", param_types) diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py index c48a2a9b2..9f152582e 100644 --- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py +++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py @@ -46,6 +46,8 @@ def _rmsnorm_fwd_triton_impl( IS_FP8: tl.constexpr, FP8_MAX: tl.constexpr, MAKE_TRANSPOSE: tl.constexpr, + INPUT_ALIGNED_16: tl.constexpr, + OUTPUT_ALIGNED_16: tl.constexpr, ): # Enable the transpose cache only 
in FP8 mode. @@ -78,7 +80,8 @@ def _rmsnorm_fwd_triton_impl( for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets input_ptrs = row_input_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) sum_squares += tl.sum(x * x, axis=0) @@ -86,7 +89,8 @@ def _rmsnorm_fwd_triton_impl( cols = n_cols_blks * BLOCK_SIZE + col_offsets mask = cols < n_cols input_ptrs = row_input_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) sum_squares += tl.sum(x * x, axis=0) @@ -101,7 +105,8 @@ def _rmsnorm_fwd_triton_impl( for blk_idx in tl.range(0, n_cols_blks, num_stages=2): cols = blk_idx * BLOCK_SIZE + col_offsets input_ptrs = row_input_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs).to(tl.float32) @@ -109,6 +114,8 @@ def _rmsnorm_fwd_triton_impl( g += 1 rms_norm = x * norm_factor * g output_ptrs = row_output_ptr + cols + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) if IS_FP8: amax_temp = tl.max(tl.abs(rms_norm), axis=-1) amax = tl.maximum(amax, amax_temp) @@ -123,6 +130,8 @@ def _rmsnorm_fwd_triton_impl( cols = n_cols_blks * BLOCK_SIZE + col_offsets mask = cols < n_cols input_ptrs = row_input_ptr + cols + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) g_ptrs = g_ptr + cols g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32) @@ -130,6 +139,8 @@ def _rmsnorm_fwd_triton_impl( g += 1 rms_norm = x * norm_factor * g output_ptrs = row_output_ptr + cols + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) if IS_FP8: amax_temp = tl.max(tl.abs(rms_norm), axis=-1) amax = tl.maximum(amax, amax_temp) @@ -144,7 +155,8 @@ def _rmsnorm_fwd_triton_impl( mask = col_offsets < n_cols for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2): input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets - input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32) g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32) row_norm = row * row @@ -160,7 +172,8 @@ def _rmsnorm_fwd_triton_impl( rms_norm = row * norm_factor * g output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets - output_ptrs = tl.multiple_of(output_ptrs, (16, )) + if OUTPUT_ALIGNED_16: + output_ptrs = tl.multiple_of(output_ptrs, (16, )) if IS_FP8: amax_temp = tl.max(tl.abs(rms_norm), axis=-1) amax = tl.maximum(amax, amax_temp) @@ -184,7 +197,9 @@ def _rmsnorm_fwd_triton_impl( @triton.jit def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr, input_row_stride, output_row_stride, n_rows, n_cols, ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr, - USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr): + USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr, + INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr, + DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr): row_start = 
tl.program_id(0) col_offsets = tl.arange(0, BLOCK_SIZE) # tl.assume(input_row_stride >= 0) @@ -209,8 +224,10 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d input_ptrs = row_input_ptr + cols grad_output_ptrs = row_grad_output_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) grad_output = tl.load(grad_output_ptrs).to(tl.float32) @@ -241,8 +258,10 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d input_ptrs = row_input_ptr + cols grad_output_ptrs = row_grad_output_ptr + cols - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) x = tl.load(input_ptrs).to(tl.float32) grad_output = tl.load(grad_output_ptrs).to(tl.float32) @@ -255,10 +274,14 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d n_cols) dx_ptrs = row_dx_ptr + cols + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty)) dg = grad_output * x * norm_factor dg_ptrs = row_dg_ptr + cols + if DG_ALIGNED_16: + dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) tl.store(dg_ptrs, dg.to(tl.float32)) # Handle remainder @@ -277,10 +300,14 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d n_cols) dx_ptrs = row_dx_ptr + cols + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask) dg = grad_output * x * norm_factor dg_ptrs = row_dg_ptr + cols + if DG_ALIGNED_16: + dg_ptrs = tl.multiple_of(dg_ptrs, (16, )) tl.store(dg_ptrs, dg.to(tl.float32), mask=mask) else: @@ -292,9 +319,12 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets - input_ptrs = tl.multiple_of(input_ptrs, (16, )) - grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) - dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) + if INPUT_ALIGNED_16: + input_ptrs = tl.multiple_of(input_ptrs, (16, )) + if GRAD_OUTPUT_ALIGNED_16: + grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, )) + if DX_ALIGNED_16: + dx_ptrs = tl.multiple_of(dx_ptrs, (16, )) x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32) grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32) @@ -352,9 +382,32 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma): dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32, requires_grad=False) if need_reduction else None grid_bwd = lambda meta: (NUM_PRGMS, ) - _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma, - x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size, - USE_BLOCKED, NUM_PRGMS, num_warps=8) + input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0) + grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(-1) % 16 == 0) + dx_aligned_16 = (dx.data_ptr() % 16 == 0) and 
(dx.stride(-1) % 16 == 0) + dg_target = dg_tmp if need_reduction else dgamma + dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(-1) % 16 == 0) + _rmsnorm_bwd_triton[grid_bwd]( + dz_, + x_, + gamma_, + rsigma_, + dx, + dg_target, + x_.stride(0), + dz_.stride(0), + M, + N, + zero_centered_gamma, + blk_size, + USE_BLOCKED, + NUM_PRGMS, + input_aligned_16, + grad_output_aligned_16, + dx_aligned_16, + dg_aligned_16, + num_warps=8, + ) if need_reduction: grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])] @@ -439,6 +492,11 @@ def te_rmsnorm_fwd_triton( grid_fwd = lambda meta: (NUM_PRGMS, ) # TODO(micky774) Implement fused MXFP8 quantization within the kernel kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl + input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(-1) % 16 == 0) + out_alignment_tensor = out._data if hasattr(out, "_data") else out + output_aligned_16 = (out_alignment_tensor.data_ptr() % 16 == 0) and ( + out_alignment_tensor.stride(-1) % 16 == 0 + ) kernel[grid_fwd]( out_ptr, input, @@ -460,6 +518,8 @@ def te_rmsnorm_fwd_triton( IS_FP8, FP8_MAX, MAKE_TRANSPOSE, + input_aligned_16, + output_aligned_16, ) if IS_MXFP8 or IS_FP8_CURRENT_SCALING: out = quantizer.quantize(out, out=ln_out) From fb4590d718e733c24184d7c021550e7ddb8a2d08 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 12 Dec 2025 09:41:21 +0000 Subject: [PATCH 04/15] [CI] Disabled gfx950 support until FP8 GEMM layout coverage is verified with hipblaslt --- transformer_engine/jax/quantize/device_utils.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/quantize/device_utils.py b/transformer_engine/jax/quantize/device_utils.py index ca90ba9fb..5fc0f0473 100644 --- a/transformer_engine/jax/quantize/device_utils.py +++ b/transformer_engine/jax/quantize/device_utils.py @@ -35,7 +35,8 @@ def get_device_compute_capability(gpu_id: int = 0) -> int: def is_fp8_gemm_with_all_layouts_supported() -> bool: """Return True if using Blackwell architecture, False otherwise.""" compute_capability = get_device_compute_capability() - if is_hip_extension(): + # Enable once FP8 GEMM layout coverage is validated with hipblaslt. + # if is_hip_extension(): # gfx950 --> NV blackwell - return compute_capability == 95 + # return compute_capability == 95 return 100 <= compute_capability < 120 From f4fa5148cca730512d2898b056a89c4891515eda Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 12 Dec 2025 22:17:53 +0000 Subject: [PATCH 05/15] [CI] [gfx950] Disable cudaGraph for gemmm and grouped-gemm --- transformer_engine/jax/csrc/extensions/gemm.cpp | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index ba2d65e3e..ec61f2a97 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -5,6 +5,7 @@ ************************************************************************/ #include "transformer_engine/gemm.h" +#include #include #include #include @@ -21,6 +22,13 @@ namespace transformer_engine { namespace jax { +#ifdef USE_ROCM +// hipblaslt GEMM is not graph-capture safe on ROCm. 
+constexpr auto GemmFFI_CudaGraph_Traits = std::initializer_list{}; +#else +constexpr auto GemmFFI_CudaGraph_Traits = FFI_CudaGraph_Traits; +#endif + static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { // Move the pointer to the next 256B aligned address return reinterpret_cast((reinterpret_cast(ptr) + 255) & @@ -200,7 +208,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Attr("fuse_gelu") .Attr("grad") .Attr("use_split_accumulator"), - FFI_CudaGraph_Traits); + GemmFFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, @@ -593,7 +601,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("scaling_mode") .Attr("has_bias") .Attr("is_grouped_dense_wgrad"), - FFI_CudaGraph_Traits); + GemmFFI_CudaGraph_Traits); } // namespace jax } // namespace transformer_engine From 2396bb86e1a03f525341f725b4ede01433deb681 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Mon, 15 Dec 2025 23:14:37 +0000 Subject: [PATCH 06/15] Addressed reviews --- tests/pytorch/test_fused_optimizer.py | 6 ++---- tests/pytorch/test_fusible_ops.py | 5 ++--- tests/pytorch/test_numerics.py | 4 ++-- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 ++ transformer_engine/jax/quantize/device_utils.py | 3 ++- transformer_engine/pytorch/utils.py | 6 ------ 6 files changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 47e8820ed..32abea1de 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -18,9 +18,7 @@ from transformer_engine.pytorch.utils import is_bf16_compatible from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import gpu_autocast_ctx - -if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi350 +from transformer_engine.pytorch.utils import get_device_compute_capability # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -382,7 +380,7 @@ def test_bf16_exp_avg(self): @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported") @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_exp_avg(self): - model_tol = 3e-2 if IS_HIP_EXTENSION and is_mi350() else None + model_tol = 3e-2 if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) else None self.gen_precision_aware_test( use_fp8_params=False, param_dtype=torch.bfloat16, diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index d8044e53b..1db81ec23 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -35,10 +35,9 @@ ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer from transformer_engine.pytorch.utils import is_bf16_compatible +from transformer_engine.pytorch.utils import get_device_compute_capability import transformer_engine_torch as tex from torch.utils.cpp_extension import IS_HIP_EXTENSION -if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi350 # Import utility functions _current_file = pathlib.Path(__file__).resolve() @@ -921,7 +920,7 @@ def test_basic_linear_quantized( quantized_grad_input: bool, ) -> None: """GEMM with FP8 inputs and outputs""" - if IS_HIP_EXTENSION and is_mi350(): + if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): if ( quantization 
and quantization.startswith("fp8") diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 7662858d1..22f0ecb69 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -28,7 +28,7 @@ is_bf16_compatible, ) if IS_HIP_EXTENSION: - from transformer_engine.pytorch.utils import is_mi200, is_mi308, is_mi350 + from transformer_engine.pytorch.utils import is_mi200, is_mi308 from transformer_engine.pytorch import ( DotProductAttention, @@ -757,7 +757,7 @@ def test_gpt_full_activation_recompute( pytest.skip("FP8 parameters are not supported in debug mode.") if recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) - if IS_HIP_EXTENSION and is_mi350(): + if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): if (dtype == torch.bfloat16 and not fp8 and not use_reentrant diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index ec61f2a97..3e0842ef5 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. diff --git a/transformer_engine/jax/quantize/device_utils.py b/transformer_engine/jax/quantize/device_utils.py index 5fc0f0473..3f04d674c 100644 --- a/transformer_engine/jax/quantize/device_utils.py +++ b/transformer_engine/jax/quantize/device_utils.py @@ -36,7 +36,8 @@ def is_fp8_gemm_with_all_layouts_supported() -> bool: """Return True if using Blackwell architecture, False otherwise.""" compute_capability = get_device_compute_capability() # Enable once FP8 GEMM layout coverage is validated with hipblaslt. 
- # if is_hip_extension(): + if is_hip_extension(): # gfx950 --> NV blackwell # return compute_capability == 95 + return False return 100 <= compute_capability < 120 diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index f49b98cb4..9d0d71fdc 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -456,12 +456,6 @@ def is_mi308(): import re return (re.search('AMD Instinct MI308', torch.cuda.get_device_name(torch.cuda.current_device())) is not None) - @functools.lru_cache(maxsize=None) - def is_mi350(): - """check whether this machine is mi35x""" - props = torch.cuda.get_device_properties(torch.cuda.current_device()) - return (props.major, props.minor) == (9, 5) - @functools.lru_cache(maxsize=None) def is_fp8_fnuz(): return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 4) From ab8a390ade7f9f1b3b364b96575db9057f517626 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 16 Dec 2025 16:18:22 +0000 Subject: [PATCH 07/15] [CI] Add MI355 nodes to github actions workflow --- .github/workflows/rocm-ci.yml | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 67af6dc9f..cbd8206f7 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -40,9 +40,12 @@ concurrency: jobs: build_and_test: - name: Build and Test on GPU + name: Build and Test on GPU (${{ matrix.runner }}) timeout-minutes: 720 - runs-on: linux-mi325-8 + runs-on: ${{ matrix.runner }} + strategy: + matrix: + runner: [linux-mi325-8, linux-mi355-8] steps: - name: Checkout repository uses: actions/checkout@v4 From d8da04eb3811b5e52425e8adbc638b0d03ed5aa5 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 16 Dec 2025 16:23:07 +0000 Subject: [PATCH 08/15] [CI] Update docker image --- ci/ci_config.json | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ci/ci_config.json b/ci/ci_config.json index 9ef4d03a2..cb5135817 100644 --- a/ci/ci_config.json +++ b/ci/ci_config.json @@ -1,6 +1,6 @@ { "docker_images": { - "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.0.2_ubuntu22.04_py3.10_pytorch_release-2.7_9015dfdf_jax_v0.6.0_fa-v2.8.0", + "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.6.0_fa-2.8.0", "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273", "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273" } From 3bcee1fb9b8324d7bdfb490369426b0a4d16449b Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 16 Dec 2025 16:46:36 +0000 Subject: [PATCH 09/15] [CI] add MI355 runner matrix and keep matrix legs independent --- .github/workflows/rocm-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index cbd8206f7..560a3ef53 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -44,6 +44,7 @@ jobs: timeout-minutes: 720 runs-on: ${{ matrix.runner }} strategy: + fail-fast: false matrix: runner: [linux-mi325-8, linux-mi355-8] steps: From d1894ef40899f4025594b87756a3f82342bb69fb Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Thu, 18 Dec 2025 06:09:32 +0000 Subject: [PATCH 10/15] Skip unstable Gemm tests on gfx950 --- 
tests/cpp/operator/test_cublaslt_gemm.cu | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 09d879efc..883ae5ab7 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -228,6 +228,14 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ + // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable. + // Re-enable after ROCm 7.2 once hipBLASLt fixes land. + if (prop.major == 9 && prop.minor == 5 && + params.transa && !params.transb && + params.m == 2304 && params.k == 768 && params.n == 4096) { + GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 until ROCm 7.2"; + } + // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0. // hipBLASLt currently supports this config only bool fp8_gelu_fusion_config = false; @@ -454,6 +462,14 @@ void performDqTest(const TestParams ¶ms) { cudaDeviceProp prop; (void)cudaGetDeviceProperties(&prop, 0); + // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable. + // Re-enable after ROCm 7.2 once hipBLASLt fixes land. + if (prop.major == 9 && prop.minor == 5 && + params.transa && !params.transb && + params.m == 2304 && params.k == 768 && params.n == 4096) { + GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 until ROCm 7.2"; + } + bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5); if (!mxfp8_supported) { GTEST_SKIP() << "MXFP8 is not supported in current config"; From b4d8c8fff8fe76b10a00e07dfbb9d8322436c7a8 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Thu, 18 Dec 2025 17:42:50 +0000 Subject: [PATCH 11/15] Addressed reviews --- ci/ci_config.json | 2 +- tests/pytorch/test_fused_optimizer.py | 2 ++ tests/pytorch/test_numerics.py | 11 ++++------- 3 files changed, 7 insertions(+), 8 deletions(-) diff --git a/ci/ci_config.json b/ci/ci_config.json index cb5135817..a7b3d5d6c 100644 --- a/ci/ci_config.json +++ b/ci/ci_config.json @@ -1,6 +1,6 @@ { "docker_images": { - "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.6.0_fa-2.8.0", + "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0", "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273", "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273" } diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index 32abea1de..3527bb9a6 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 22f0ecb69..b2885d677 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2813,14 +2813,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): max_seqlen_kv=config.seq_len, ) - if IS_HIP_EXTENSION: + if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): tols_thd = dtype_tols(dtype) - if dtype in (torch.float16, torch.bfloat16): - # ROCm fused attention (CK) on THD can produce slightly larger error - tols_thd["atol"] = 2e-3 - _, use_aotriton, use_ck = rocm_attn_backend() - if use_aotriton and not use_ck: - tols_thd["rtol"] = tols_thd["rtol"] * 3 + # On gfx950 the results for THD are different + # that results in lower final result precision + tols_thd["atol"] = 2e-3 torch.testing.assert_close( y_bshd, y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), From 8aca8e0f24588aed997d846497ee0905dfe5d90f Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Sat, 20 Dec 2025 00:22:52 +0000 Subject: [PATCH 12/15] Guard gfx950 TN skip by ROCm version and adjust MXFP8 Dq test size --- tests/cpp/operator/test_cublaslt_gemm.cu | 26 ++++++++---------------- 1 file changed, 9 insertions(+), 17 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 883ae5ab7..61ca86a1e 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -29,8 +29,8 @@ std::vector> test_case_sizes = { }; std::vector> test_case_sizes_mxfp8 = { - {2304, 768, 4096}, -}; + {768, 3072, 4096}, +}; // A, B, Bias, Gelu, D // Bias type choose as bf16 in use_fp8, D_type otherwise @@ -228,13 +228,13 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ - // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable. - // Re-enable after ROCm 7.2 once hipBLASLt fixes land. - if (prop.major == 9 && prop.minor == 5 && - params.transa && !params.transb && - params.m == 2304 && params.k == 768 && params.n == 4096) { - GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 until ROCm 7.2"; - } + #if HIP_VERSION < 70200000 + if (prop.major == 9 && prop.minor == 5 && + params.transa && !params.transb && + params.m == 2304 && params.k == 768 && params.n == 4096) { + GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2"; + } + #endif // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0. // hipBLASLt currently supports this config only @@ -462,14 +462,6 @@ void performDqTest(const TestParams ¶ms) { cudaDeviceProp prop; (void)cudaGetDeviceProperties(&prop, 0); - // Temporary skip: gfx950 TN kernels for (M,K,N)=(2304,768,4096) are unstable. - // Re-enable after ROCm 7.2 once hipBLASLt fixes land. 
- if (prop.major == 9 && prop.minor == 5 && - params.transa && !params.transb && - params.m == 2304 && params.k == 768 && params.n == 4096) { - GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 until ROCm 7.2"; - } - bool mxfp8_supported = (prop.major == 9 && prop.minor >= 5); if (!mxfp8_supported) { GTEST_SKIP() << "MXFP8 is not supported in current config"; From 93c118b60578b90865e1739bc9d5f6546819a5db Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Tue, 6 Jan 2026 16:43:10 +0000 Subject: [PATCH 13/15] Removed ROCM7.2 guards --- tests/cpp/operator/test_cublaslt_gemm.cu | 8 -------- 1 file changed, 8 deletions(-) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 61ca86a1e..90364d4bc 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -228,14 +228,6 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ - #if HIP_VERSION < 70200000 - if (prop.major == 9 && prop.minor == 5 && - params.transa && !params.transb && - params.m == 2304 && params.k == 768 && params.n == 4096) { - GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2"; - } - #endif - // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0. // hipBLASLt currently supports this config only bool fp8_gelu_fusion_config = false; From b551b3f519c5e29b5ea904f4ac60766bfc0aad69 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Wed, 7 Jan 2026 16:21:10 +0000 Subject: [PATCH 14/15] Reverted ROCM7.2 guards --- tests/cpp/operator/test_cublaslt_gemm.cu | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu index 90364d4bc..61ca86a1e 100644 --- a/tests/cpp/operator/test_cublaslt_gemm.cu +++ b/tests/cpp/operator/test_cublaslt_gemm.cu @@ -228,6 +228,14 @@ void performTest(const TestParams& params) { #ifdef __HIP_PLATFORM_AMD__ + #if HIP_VERSION < 70200000 + if (prop.major == 9 && prop.minor == 5 && + params.transa && !params.transb && + params.m == 2304 && params.k == 768 && params.n == 4096) { + GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2"; + } + #endif + // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0. // hipBLASLt currently supports this config only bool fp8_gelu_fusion_config = false; From b52594b4dd9ee6b97a53c2628c3d9fbc90e317d4 Mon Sep 17 00:00:00 2001 From: Veera Rajasekhar Reddy Gopu Date: Fri, 9 Jan 2026 01:47:10 -0600 Subject: [PATCH 15/15] Update rocm-ci.yml --- .github/workflows/rocm-ci.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml index 560a3ef53..4fc34e391 100644 --- a/.github/workflows/rocm-ci.yml +++ b/.github/workflows/rocm-ci.yml @@ -426,7 +426,7 @@ jobs: if: always() uses: actions/upload-artifact@v4 with: - name: logs-and-reports + name: logs-and-reports-${{ matrix.runner }} path: | *.log if-no-files-found: ignore