diff --git a/.github/workflows/rocm-ci.yml b/.github/workflows/rocm-ci.yml
index 67af6dc9f..4fc34e391 100644
--- a/.github/workflows/rocm-ci.yml
+++ b/.github/workflows/rocm-ci.yml
@@ -40,9 +40,13 @@ concurrency:
 
 jobs:
   build_and_test:
-    name: Build and Test on GPU
+    name: Build and Test on GPU (${{ matrix.runner }})
     timeout-minutes: 720
-    runs-on: linux-mi325-8
+    runs-on: ${{ matrix.runner }}
+    strategy:
+      fail-fast: false
+      matrix:
+        runner: [linux-mi325-8, linux-mi355-8]
     steps:
       - name: Checkout repository
         uses: actions/checkout@v4
@@ -422,7 +426,7 @@ jobs:
         if: always()
         uses: actions/upload-artifact@v4
         with:
-          name: logs-and-reports
+          name: logs-and-reports-${{ matrix.runner }}
           path: |
             *.log
           if-no-files-found: ignore
diff --git a/ci/ci_config.json b/ci/ci_config.json
index 9ef4d03a2..a7b3d5d6c 100644
--- a/ci/ci_config.json
+++ b/ci/ci_config.json
@@ -1,6 +1,6 @@
 {
   "docker_images": {
-    "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.0.2_ubuntu22.04_py3.10_pytorch_release-2.7_9015dfdf_jax_v0.6.0_fa-v2.8.0",
+    "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0",
     "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273",
     "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273"
   }
diff --git a/tests/cpp/operator/test_cublaslt_gemm.cu b/tests/cpp/operator/test_cublaslt_gemm.cu
index 09d879efc..61ca86a1e 100644
--- a/tests/cpp/operator/test_cublaslt_gemm.cu
+++ b/tests/cpp/operator/test_cublaslt_gemm.cu
@@ -29,8 +29,8 @@ std::vector> test_case_sizes = {
 };
 
 std::vector> test_case_sizes_mxfp8 = {
-  {2304, 768, 4096},
-};
+  {768, 3072, 4096},
+};
 
 // A, B, Bias, Gelu, D
 // Bias type choose as bf16 in use_fp8, D_type otherwise
@@ -228,6 +228,14 @@ void performTest(const TestParams& params) {
 
 #ifdef __HIP_PLATFORM_AMD__
 
+  #if HIP_VERSION < 70200000
+  if (prop.major == 9 && prop.minor == 5 &&
+      params.transa && !params.transb &&
+      params.m == 2304 && params.k == 768 && params.n == 4096) {
+    GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2";
+  }
+  #endif
+
   // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0.
   // hipBLASLt currently supports this config only
   bool fp8_gelu_fusion_config = false;
diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py
index e04f0477b..3527bb9a6 100644
--- a/tests/pytorch/test_fused_optimizer.py
+++ b/tests/pytorch/test_fused_optimizer.py
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -9,6 +11,7 @@
 import pytest
 import torch
 from torch import nn
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
@@ -17,6 +20,7 @@
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import gpu_autocast_ctx
+from transformer_engine.pytorch.utils import get_device_compute_capability
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -378,6 +382,7 @@ def test_bf16_exp_avg(self):
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg(self):
+        model_tol = 3e-2 if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) else None
         self.gen_precision_aware_test(
             use_fp8_params=False,
             param_dtype=torch.bfloat16,
@@ -388,6 +393,8 @@ def test_fp8_exp_avg(self):
             exp_avg_sq_dtype=torch.float32,
             master_rtol=1e-2,
             master_atol=1e-2,
+            model_rtol=model_tol,
+            model_atol=model_tol,
         )
 
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py
index 78894d97d..1db81ec23 100644
--- a/tests/pytorch/test_fusible_ops.py
+++ b/tests/pytorch/test_fusible_ops.py
@@ -35,7 +35,9 @@
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.utils import get_device_compute_capability
 import transformer_engine_torch as tex
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 # Import utility functions
 _current_file = pathlib.Path(__file__).resolve()
@@ -918,6 +920,16 @@ def test_basic_linear_quantized(
         quantized_grad_input: bool,
     ) -> None:
         """GEMM with FP8 inputs and outputs"""
+        if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+            if (
+                quantization
+                and quantization.startswith("fp8")
+                and quantized_compute
+                and (quantized_grad_input or quantized_output)
+            ):
+                pytest.skip(
+                    "hipBLASLt does not provide suitable algorithms on gfx950 for this config."
+                )
         if quantization is None:
             pytest.skip("Skipping case without quantization")
         self._test_basic_linear(
diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py
index 1787ab191..b2885d677 100644
--- a/tests/pytorch/test_numerics.py
+++ b/tests/pytorch/test_numerics.py
@@ -757,6 +757,13 @@ def test_gpt_full_activation_recompute(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+        if (dtype == torch.bfloat16
+            and not fp8
+            and not use_reentrant
+            and recipe.float8_per_tensor_scaling()
+        ):
+            pytest.skip("hipBLASLt does not provide suitable algorithms on MI350 for this config.")
 
     config = model_configs[model]
     torch.compiler.reset()  # avoid cache size limit overflow
@@ -2806,10 +2813,21 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         max_seqlen_kv=config.seq_len,
     )
 
-    torch.testing.assert_close(
-        y_bshd,
-        y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
-    )
+    if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+        tols_thd = dtype_tols(dtype)
+        # On gfx950 the THD results differ slightly from the BSHD path,
+        # which lowers the precision of the final comparison.
+        tols_thd["atol"] = 2e-3
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+            **tols_thd,
+        )
+    else:
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        )
 
 
 @pytest.mark.parametrize("dtype", param_types)
diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp
index ba2d65e3e..3e0842ef5 100644
--- a/transformer_engine/jax/csrc/extensions/gemm.cpp
+++ b/transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -1,10 +1,13 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 #include "transformer_engine/gemm.h"
 
+#include
 #include
 #include
 #include
@@ -21,6 +24,13 @@
 namespace transformer_engine {
 namespace jax {
 
+#ifdef USE_ROCM
+// hipblaslt GEMM is not graph-capture safe on ROCm.
+constexpr auto GemmFFI_CudaGraph_Traits = std::initializer_list{};
+#else
+constexpr auto GemmFFI_CudaGraph_Traits = FFI_CudaGraph_Traits;
+#endif
+
 static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
   // Move the pointer to the next 256B aligned address
   return reinterpret_cast((reinterpret_cast(ptr) + 255) &
@@ -200,7 +210,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
                               .Attr("fuse_gelu")
                               .Attr("grad")
                               .Attr("use_split_accumulator"),
-                              FFI_CudaGraph_Traits);
+                              GemmFFI_CudaGraph_Traits);
 
 Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                           Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
@@ -593,7 +603,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                               .Attr("scaling_mode")
                               .Attr("has_bias")
                               .Attr("is_grouped_dense_wgrad"),
-                              FFI_CudaGraph_Traits);
+                              GemmFFI_CudaGraph_Traits);
 
 } // namespace jax
 } // namespace transformer_engine
diff --git a/transformer_engine/jax/quantize/device_utils.py b/transformer_engine/jax/quantize/device_utils.py
index ca90ba9fb..3f04d674c 100644
--- a/transformer_engine/jax/quantize/device_utils.py
+++ b/transformer_engine/jax/quantize/device_utils.py
@@ -35,7 +35,9 @@ def get_device_compute_capability(gpu_id: int = 0) -> int:
 def is_fp8_gemm_with_all_layouts_supported() -> bool:
     """Return True if using Blackwell architecture, False otherwise."""
     compute_capability = get_device_compute_capability()
+    # Enable once FP8 GEMM layout coverage is validated with hipblaslt.
     if is_hip_extension():
         # gfx950 --> NV blackwell
-        return compute_capability == 95
+        # return compute_capability == 95
+        return False
     return 100 <= compute_capability < 120
diff --git a/transformer_engine/pytorch/triton_kernels/rmsnorm.py b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
index c48a2a9b2..9f152582e 100644
--- a/transformer_engine/pytorch/triton_kernels/rmsnorm.py
+++ b/transformer_engine/pytorch/triton_kernels/rmsnorm.py
@@ -46,6 +46,8 @@ def _rmsnorm_fwd_triton_impl(
     IS_FP8: tl.constexpr,
     FP8_MAX: tl.constexpr,
     MAKE_TRANSPOSE: tl.constexpr,
+    INPUT_ALIGNED_16: tl.constexpr,
+    OUTPUT_ALIGNED_16: tl.constexpr,
 ):
 
     # Enable the transpose cache only in FP8 mode.
@@ -78,7 +80,8 @@ def _rmsnorm_fwd_triton_impl(
             for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
                 cols = blk_idx * BLOCK_SIZE + col_offsets
                 input_ptrs = row_input_ptr + cols
-                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
+                if INPUT_ALIGNED_16:
+                    input_ptrs = tl.multiple_of(input_ptrs, (16, ))
                 x = tl.load(input_ptrs).to(tl.float32)
                 sum_squares += tl.sum(x * x, axis=0)
 
@@ -86,7 +89,8 @@ def _rmsnorm_fwd_triton_impl(
             cols = n_cols_blks * BLOCK_SIZE + col_offsets
             mask = cols < n_cols
             input_ptrs = row_input_ptr + cols
-            input_ptrs = tl.multiple_of(input_ptrs, (16, ))
+            if INPUT_ALIGNED_16:
+                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
             x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
             sum_squares += tl.sum(x * x, axis=0)
 
@@ -101,7 +105,8 @@ def _rmsnorm_fwd_triton_impl(
             for blk_idx in tl.range(0, n_cols_blks, num_stages=2):
                 cols = blk_idx * BLOCK_SIZE + col_offsets
                 input_ptrs = row_input_ptr + cols
-                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
+                if INPUT_ALIGNED_16:
+                    input_ptrs = tl.multiple_of(input_ptrs, (16, ))
                 x = tl.load(input_ptrs).to(tl.float32)
                 g_ptrs = g_ptr + cols
                 g = tl.load(g_ptrs).to(tl.float32)
@@ -109,6 +114,8 @@ def _rmsnorm_fwd_triton_impl(
                     g += 1
                 rms_norm = x * norm_factor * g
                 output_ptrs = row_output_ptr + cols
+                if OUTPUT_ALIGNED_16:
+                    output_ptrs = tl.multiple_of(output_ptrs, (16, ))
                 if IS_FP8:
                     amax_temp = tl.max(tl.abs(rms_norm), axis=-1)
                     amax = tl.maximum(amax, amax_temp)
@@ -123,6 +130,8 @@ def _rmsnorm_fwd_triton_impl(
             cols = n_cols_blks * BLOCK_SIZE + col_offsets
             mask = cols < n_cols
             input_ptrs = row_input_ptr + cols
+            if INPUT_ALIGNED_16:
+                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
             x = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
             g_ptrs = g_ptr + cols
             g = tl.load(g_ptrs, mask=mask, other=0.0).to(tl.float32)
@@ -130,6 +139,8 @@ def _rmsnorm_fwd_triton_impl(
                 g += 1
             rms_norm = x * norm_factor * g
             output_ptrs = row_output_ptr + cols
+            if OUTPUT_ALIGNED_16:
+                output_ptrs = tl.multiple_of(output_ptrs, (16, ))
            if IS_FP8:
                 amax_temp = tl.max(tl.abs(rms_norm), axis=-1)
                 amax = tl.maximum(amax, amax_temp)
@@ -144,7 +155,8 @@ def _rmsnorm_fwd_triton_impl(
         mask = col_offsets < n_cols
         for row_idx in tl.range(row_start, n_rows, NUM_PRGMS, num_stages=2):
             input_ptrs = input_ptr + row_idx * input_row_stride + col_offsets
-            input_ptrs = tl.multiple_of(input_ptrs, (16, ))
+            if INPUT_ALIGNED_16:
+                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
             row = tl.load(input_ptrs, mask=mask, other=0.0, cache_modifier=".cg").to(tl.float32)
             g = tl.load(g_ptr + col_offsets, mask=mask, other=0.0).to(tl.float32)
             row_norm = row * row
@@ -160,7 +172,8 @@ def _rmsnorm_fwd_triton_impl(
 
             rms_norm = row * norm_factor * g
             output_ptrs = output_ptr + row_idx * output_row_stride + col_offsets
-            output_ptrs = tl.multiple_of(output_ptrs, (16, ))
+            if OUTPUT_ALIGNED_16:
+                output_ptrs = tl.multiple_of(output_ptrs, (16, ))
             if IS_FP8:
                 amax_temp = tl.max(tl.abs(rms_norm), axis=-1)
                 amax = tl.maximum(amax, amax_temp)
@@ -184,7 +197,9 @@ def _rmsnorm_fwd_triton_impl(
 @triton.jit
 def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, dg_ptr,
                         input_row_stride, output_row_stride, n_rows, n_cols,
                         ZERO_CENTERED_GAMMA: tl.constexpr, BLOCK_SIZE: tl.constexpr,
-                        USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr):
+                        USE_BLOCKED: tl.constexpr, NUM_PRGMS: tl.constexpr,
+                        INPUT_ALIGNED_16: tl.constexpr, GRAD_OUTPUT_ALIGNED_16: tl.constexpr,
+                        DX_ALIGNED_16: tl.constexpr, DG_ALIGNED_16: tl.constexpr):
     row_start = tl.program_id(0)
     col_offsets = tl.arange(0, BLOCK_SIZE)
     # tl.assume(input_row_stride >= 0)
@@ -209,8 +224,10 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d
                 input_ptrs = row_input_ptr + cols
                 grad_output_ptrs = row_grad_output_ptr + cols
 
-                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-                grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
+                if INPUT_ALIGNED_16:
+                    input_ptrs = tl.multiple_of(input_ptrs, (16, ))
+                if GRAD_OUTPUT_ALIGNED_16:
+                    grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
 
                 x = tl.load(input_ptrs).to(tl.float32)
                 grad_output = tl.load(grad_output_ptrs).to(tl.float32)
@@ -241,8 +258,10 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d
                 input_ptrs = row_input_ptr + cols
                 grad_output_ptrs = row_grad_output_ptr + cols
 
-                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-                grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
+                if INPUT_ALIGNED_16:
+                    input_ptrs = tl.multiple_of(input_ptrs, (16, ))
+                if GRAD_OUTPUT_ALIGNED_16:
+                    grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
 
                 x = tl.load(input_ptrs).to(tl.float32)
                 grad_output = tl.load(grad_output_ptrs).to(tl.float32)
@@ -255,10 +274,14 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d
                                     n_cols)
 
                 dx_ptrs = row_dx_ptr + cols
+                if DX_ALIGNED_16:
+                    dx_ptrs = tl.multiple_of(dx_ptrs, (16, ))
                 tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty))
 
                 dg = grad_output * x * norm_factor
                 dg_ptrs = row_dg_ptr + cols
+                if DG_ALIGNED_16:
+                    dg_ptrs = tl.multiple_of(dg_ptrs, (16, ))
                 tl.store(dg_ptrs, dg.to(tl.float32))
 
             # Handle remainder
@@ -277,10 +300,14 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d
                                     n_cols)
 
             dx_ptrs = row_dx_ptr + cols
+            if DX_ALIGNED_16:
+                dx_ptrs = tl.multiple_of(dx_ptrs, (16, ))
             tl.store(dx_ptrs, grad_input.to(dx_ptr.type.element_ty), mask=mask)
 
             dg = grad_output * x * norm_factor
             dg_ptrs = row_dg_ptr + cols
+            if DG_ALIGNED_16:
+                dg_ptrs = tl.multiple_of(dg_ptrs, (16, ))
             tl.store(dg_ptrs, dg.to(tl.float32), mask=mask)
 
     else:
@@ -292,9 +319,12 @@ def _rmsnorm_bwd_triton(grad_output_ptr, input_ptr, g_ptr, rsigma_ptr, dx_ptr, d
             grad_output_ptrs = grad_output_ptr + row_idx * output_row_stride + col_offsets
             dx_ptrs = dx_ptr + row_idx * input_row_stride + col_offsets
 
-            input_ptrs = tl.multiple_of(input_ptrs, (16, ))
-            grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
-            dx_ptrs = tl.multiple_of(dx_ptrs, (16, ))
+            if INPUT_ALIGNED_16:
+                input_ptrs = tl.multiple_of(input_ptrs, (16, ))
+            if GRAD_OUTPUT_ALIGNED_16:
+                grad_output_ptrs = tl.multiple_of(grad_output_ptrs, (16, ))
+            if DX_ALIGNED_16:
+                dx_ptrs = tl.multiple_of(dx_ptrs, (16, ))
 
             x = tl.load(input_ptrs, mask=mask, other=0.0).to(tl.float32)
             grad_output = tl.load(grad_output_ptrs, mask=mask, other=0.0).to(tl.float32)
@@ -352,9 +382,32 @@ def te_rmsnorm_bwd_triton(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma):
     dg_tmp = torch.empty(dg_tmp_rows(x_, sm_margin), N, device=x.device, dtype=torch.float32,
                          requires_grad=False) if need_reduction else None
     grid_bwd = lambda meta: (NUM_PRGMS, )
-    _rmsnorm_bwd_triton[grid_bwd](dz_, x_, gamma_, rsigma_, dx, dg_tmp if need_reduction else dgamma,
-                                  x_.stride(0), dz_.stride(0), M, N, zero_centered_gamma, blk_size,
-                                  USE_BLOCKED, NUM_PRGMS, num_warps=8)
+    input_aligned_16 = (x_.data_ptr() % 16 == 0) and (x_.stride(-1) % 16 == 0)
+    grad_output_aligned_16 = (dz_.data_ptr() % 16 == 0) and (dz_.stride(-1) % 16 == 0)
+    dx_aligned_16 = (dx.data_ptr() % 16 == 0) and (dx.stride(-1) % 16 == 0)
+    dg_target = dg_tmp if need_reduction else dgamma
+    dg_aligned_16 = (dg_target.data_ptr() % 16 == 0) and (dg_target.stride(-1) % 16 == 0)
+    _rmsnorm_bwd_triton[grid_bwd](
+        dz_,
+        x_,
+        gamma_,
+        rsigma_,
+        dx,
+        dg_target,
+        x_.stride(0),
+        dz_.stride(0),
+        M,
+        N,
+        zero_centered_gamma,
+        blk_size,
+        USE_BLOCKED,
+        NUM_PRGMS,
+        input_aligned_16,
+        grad_output_aligned_16,
+        dx_aligned_16,
+        dg_aligned_16,
+        num_warps=8,
+    )
 
     if need_reduction:
         grid_reduce = lambda meta: [triton.cdiv(N, meta['BLOCK_SIZE_N'])]
@@ -439,6 +492,11 @@ def te_rmsnorm_fwd_triton(
     grid_fwd = lambda meta: (NUM_PRGMS, )
     # TODO(micky774) Implement fused MXFP8 quantization within the kernel
     kernel = _rmsnorm_fwd_triton if autotune else _rmsnorm_fwd_triton_impl
+    input_aligned_16 = (input.data_ptr() % 16 == 0) and (input.stride(-1) % 16 == 0)
+    out_alignment_tensor = out._data if hasattr(out, "_data") else out
+    output_aligned_16 = (out_alignment_tensor.data_ptr() % 16 == 0) and (
+        out_alignment_tensor.stride(-1) % 16 == 0
+    )
     kernel[grid_fwd](
         out_ptr,
         input,
@@ -460,6 +518,8 @@ def te_rmsnorm_fwd_triton(
         IS_FP8,
         FP8_MAX,
         MAKE_TRANSPOSE,
+        input_aligned_16,
+        output_aligned_16,
     )
     if IS_MXFP8 or IS_FP8_CURRENT_SCALING:
         out = quantizer.quantize(out, out=ln_out)
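# Reviewer note (illustration only, not part of the patch): the new test skips in
# test_fused_optimizer.py, test_fusible_ops.py, and test_numerics.py all key off the
# same gfx950 (MI350-class) guard. Below is a minimal sketch of that guard, assuming
# get_device_compute_capability() returns a (major, minor) tuple on ROCm builds, as
# the comparisons in the diff rely on; the helper name `is_gfx950` is hypothetical.
from torch.utils.cpp_extension import IS_HIP_EXTENSION
from transformer_engine.pytorch.utils import get_device_compute_capability


def is_gfx950() -> bool:
    # gfx950 reports compute capability (9, 5) through the PyTorch utilities used
    # by these tests; on non-HIP builds the guard is always False.
    return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5)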