Merged. Changes from all commits.
10 changes: 7 additions & 3 deletions .github/workflows/rocm-ci.yml
@@ -40,9 +40,13 @@ concurrency:
 
 jobs:
   build_and_test:
-    name: Build and Test on GPU
+    name: Build and Test on GPU (${{ matrix.runner }})
     timeout-minutes: 720
-    runs-on: linux-mi325-8
+    runs-on: ${{ matrix.runner }}
+    strategy:
+      fail-fast: false
+      matrix:
+        runner: [linux-mi325-8, linux-mi355-8]
     steps:
       - name: Checkout repository
        uses: actions/checkout@v4
@@ -422,7 +426,7 @@ jobs:
        if: always()
        uses: actions/upload-artifact@v4
        with:
-         name: logs-and-reports
+         name: logs-and-reports-${{ matrix.runner }}
          path: |
            *.log
          if-no-files-found: ignore
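The matrix entry fans the single GPU job out into one run per runner label, and the per-runner suffix keeps the uploaded artifacts from colliding. A minimal sketch of the resulting job and artifact names (illustrative only; the runner labels come from the matrix above):

    # Illustrative sketch: how the two matrix entries expand into job/artifact names.
    runners = ["linux-mi325-8", "linux-mi355-8"]  # from the workflow matrix above
    for runner in runners:
        job_name = f"Build and Test on GPU ({runner})"
        artifact_name = f"logs-and-reports-{runner}"
        print(f"{job_name} -> {artifact_name}")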
2 changes: 1 addition & 1 deletion ci/ci_config.json
@@ -1,6 +1,6 @@
 {
     "docker_images": {
-        "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.0.2_ubuntu22.04_py3.10_pytorch_release-2.7_9015dfdf_jax_v0.6.0_fa-v2.8.0",
+        "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0",
         "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273",
         "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273"
     }
12 changes: 10 additions & 2 deletions tests/cpp/operator/test_cublaslt_gemm.cu
@@ -29,8 +29,8 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {
 };
 
 std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
-  {2304, 768, 4096},
-};
+  {768, 3072, 4096},
+};
 
 // A, B, Bias, Gelu, D
 // Bias type choose as bf16 in use_fp8, D_type otherwise
@@ -228,6 +228,14 @@ void performTest(const TestParams& params) {
 
 #ifdef __HIP_PLATFORM_AMD__
 
+#if HIP_VERSION < 70200000
+  if (prop.major == 9 && prop.minor == 5 &&
+      params.transa && !params.transb &&
+      params.m == 2304 && params.k == 768 && params.n == 4096) {
+    GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2";
+  }
+#endif
+
   // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0.
   // hipBLASLt currently supports this config only
   bool fp8_gelu_fusion_config = false;
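The 70200000 threshold follows HIP's packed version scheme (assuming the usual hip_version.h encoding of major * 10,000,000 + minor * 100,000 + patch), so it corresponds to ROCm/HIP 7.2.0 and the TN 2304x768x4096 case is skipped on gfx950 only for older runtimes. A quick sanity check of that arithmetic:

    # Sanity check of the HIP_VERSION guard above, assuming HIP's usual packed
    # encoding: major * 10_000_000 + minor * 100_000 + patch.
    def hip_version(major: int, minor: int, patch: int = 0) -> int:
        return major * 10_000_000 + minor * 100_000 + patch

    assert hip_version(7, 2, 0) == 70200000  # the threshold used in the guard
    assert hip_version(7, 1, 1) < 70200000   # a 7.1.x runtime (e.g. the rocm-7.1.1 CI image) still skips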
7 changes: 7 additions & 0 deletions tests/pytorch/test_fused_optimizer.py
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -9,6 +11,7 @@
 import pytest
 import torch
 from torch import nn
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
@@ -17,6 +20,7 @@
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import gpu_autocast_ctx
+from transformer_engine.pytorch.utils import get_device_compute_capability
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -378,6 +382,7 @@ def test_bf16_exp_avg(self):
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg(self):
+        model_tol = 3e-2 if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) else None
         self.gen_precision_aware_test(
             use_fp8_params=False,
             param_dtype=torch.bfloat16,
@@ -388,6 +393,8 @@ def test_fp8_exp_avg(self):
             exp_avg_sq_dtype=torch.float32,
             master_rtol=1e-2,
             master_atol=1e-2,
+            model_rtol=model_tol,
+            model_atol=model_tol,
         )
 
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
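The tolerance tweak keys off a device check that recurs throughout this PR: transformer_engine.pytorch.utils.get_device_compute_capability returns a (major, minor) tuple, and on ROCm builds gfx950 (MI350-class) reports as (9, 5). A small helper capturing that idiom, assuming the same imports as the test file:

    # Sketch of the gfx950 guard used in these tests; assumes the ROCm build
    # reports gfx950 as compute capability (9, 5), as in the diffs above.
    from torch.utils.cpp_extension import IS_HIP_EXTENSION
    from transformer_engine.pytorch.utils import get_device_compute_capability

    def running_on_gfx950() -> bool:
        """True when PyTorch targets HIP and the active device is gfx950/MI350-class."""
        return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5)

    # test_fp8_exp_avg then loosens only the model-state tolerances on that target:
    # model_tol = 3e-2 if running_on_gfx950() else None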
12 changes: 12 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -35,7 +35,9 @@
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.utils import get_device_compute_capability
 import transformer_engine_torch as tex
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 # Import utility functions
 _current_file = pathlib.Path(__file__).resolve()
@@ -918,6 +920,16 @@ def test_basic_linear_quantized(
         quantized_grad_input: bool,
     ) -> None:
         """GEMM with FP8 inputs and outputs"""
+        if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+            if (
+                quantization
+                and quantization.startswith("fp8")
+                and quantized_compute
+                and (quantized_grad_input or quantized_output)
+            ):
+                pytest.skip(
+                    "hipBLASLt does not provide suitable algorithms on gfx950 for this config."
+                )
        if quantization is None:
            pytest.skip("Skipping case without quantization")
        self._test_basic_linear(
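On gfx950 this guard prunes only the FP8 cases where the compute path is quantized and at least one of the output or grad-input tensors is quantized as well; everything else still runs. An illustrative enumeration of the affected flag combinations, assuming an "fp8*" quantization recipe:

    # Illustrative only: which flag combinations the gfx950 guard above skips,
    # assuming quantization is an "fp8*" recipe string.
    from itertools import product

    for quantized_compute, quantized_output, quantized_grad_input in product([False, True], repeat=3):
        skipped = quantized_compute and (quantized_output or quantized_grad_input)
        status = "skip" if skipped else "run"
        print(f"compute={quantized_compute} output={quantized_output} "
              f"grad_input={quantized_grad_input} -> {status}")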
26 changes: 22 additions & 4 deletions tests/pytorch/test_numerics.py
@@ -757,6 +757,13 @@ def test_gpt_full_activation_recompute(
         pytest.skip("FP8 parameters are not supported in debug mode.")
     if recipe.float8_block_scaling() and not fp8_block_scaling_available:
         pytest.skip(reason_for_no_fp8_block_scaling)
+    if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+        if (dtype == torch.bfloat16
+            and not fp8
+            and not use_reentrant
+            and recipe.float8_per_tensor_scaling()
+        ):
+            pytest.skip("hipBLASLt does not provide suitable algorithms on MI350 for this config.")
 
     config = model_configs[model]
     torch.compiler.reset()  # avoid cache size limit overflow
@@ -2806,10 +2813,21 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         max_seqlen_kv=config.seq_len,
     )
 
-    torch.testing.assert_close(
-        y_bshd,
-        y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
-    )
+    if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+        tols_thd = dtype_tols(dtype)
+        # On gfx950 the results for THD are different
+        # that results in lower final result precision
+        tols_thd["atol"] = 2e-3
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+            **tols_thd,
+        )
+    else:
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(),
+        )
 
 
 @pytest.mark.parametrize("dtype", param_types)
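The THD branch reuses the dtype-derived tolerances and only widens the absolute tolerance to 2e-3, leaving rtol untouched. A minimal sketch of that pattern, assuming dtype_tols(dtype) returns a dict with "rtol" and "atol" keys as it is used above:

    # Pattern used above: start from the dtype-derived tolerances, then widen atol only.
    # dtype_tols is the test-suite helper; its module is not shown in this diff, so it
    # is passed in explicitly here.
    def thd_tols(dtype, dtype_tols):
        tols = dict(dtype_tols(dtype))  # copy so the shared defaults stay untouched
        tols["atol"] = 2e-3             # gfx950 THD path needs a looser absolute tolerance
        return tols

    # torch.testing.assert_close(y_bshd, y_thd_reshaped, **thd_tols(dtype, dtype_tols))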
14 changes: 12 additions & 2 deletions transformer_engine/jax/csrc/extensions/gemm.cpp
@@ -1,10 +1,13 @@
 /*************************************************************************
+ * This file was modified for portability to AMDGPU
+ * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
  * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/
 #include "transformer_engine/gemm.h"
 
+#include <initializer_list>
 #include <memory>
 #include <string_view>
 #include <tuple>
@@ -21,6 +24,13 @@
 namespace transformer_engine {
 namespace jax {
 
+#ifdef USE_ROCM
+// hipblaslt GEMM is not graph-capture safe on ROCm.
+constexpr auto GemmFFI_CudaGraph_Traits = std::initializer_list<xla::ffi::Traits>{};
+#else
+constexpr auto GemmFFI_CudaGraph_Traits = FFI_CudaGraph_Traits;
+#endif
+
 static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) {
   // Move the pointer to the next 256B aligned address
   return reinterpret_cast<uint8_t *>((reinterpret_cast<uintptr_t>(ptr) + 255) &
@@ -200,7 +210,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI,
                                   .Attr<bool>("fuse_gelu")
                                   .Attr<bool>("grad")
                                   .Attr<bool>("use_split_accumulator"),
-                              FFI_CudaGraph_Traits);
+                              GemmFFI_CudaGraph_Traits);
 
 Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv,
                           Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias,
@@ -593,7 +603,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI,
                                   .Attr<JAXX_Scaling_Mode>("scaling_mode")
                                   .Attr<bool>("has_bias")
                                   .Attr<bool>("is_grouped_dense_wgrad"),
-                              FFI_CudaGraph_Traits);
+                              GemmFFI_CudaGraph_Traits);
 
 }  // namespace jax
 }  // namespace transformer_engine
4 changes: 3 additions & 1 deletion transformer_engine/jax/quantize/device_utils.py
@@ -35,7 +35,9 @@ def get_device_compute_capability(gpu_id: int = 0) -> int:
 def is_fp8_gemm_with_all_layouts_supported() -> bool:
     """Return True if using Blackwell architecture, False otherwise."""
     compute_capability = get_device_compute_capability()
+    # Enable once FP8 GEMM layout coverage is validated with hipblaslt.
     if is_hip_extension():
         # gfx950 --> NV blackwell
-        return compute_capability == 95
+        # return compute_capability == 95
+        return False
     return 100 <= compute_capability < 120
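Two compute-capability conventions appear in this PR: the JAX-side helper here compares against a single integer (95 for gfx950, 100 to 119 for Blackwell), while the PyTorch-side helper used in the tests returns a (major, minor) tuple such as (9, 5). A tiny illustration, with both values taken from the diffs above:

    # Illustrative only: the two compute-capability conventions visible in this PR.
    # JAX side (transformer_engine.jax.quantize.device_utils) works with a single int,
    # e.g. 95 for gfx950; the PyTorch side returns a (major, minor) tuple, e.g. (9, 5).
    jax_style = 95
    torch_style = (9, 5)
    assert jax_style == torch_style[0] * 10 + torch_style[1]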