
Commit 7d2ed36

Merge dev f141f34
2 parents b95717e + f141f34 commit 7d2ed36

File tree: 29 files changed, +415 -923 lines

.github/workflows/rocm-ci.yml

Lines changed: 7 additions & 3 deletions
@@ -40,9 +40,13 @@ concurrency:
 
 jobs:
   build_and_test:
-    name: Build and Test on GPU
+    name: Build and Test on GPU (${{ matrix.runner }})
     timeout-minutes: 720
-    runs-on: linux-mi325-8
+    runs-on: ${{ matrix.runner }}
+    strategy:
+      fail-fast: false
+      matrix:
+        runner: [linux-mi325-8, linux-mi355-8]
     steps:
       - name: Checkout repository
        uses: actions/checkout@v4
@@ -422,7 +426,7 @@ jobs:
        if: always()
        uses: actions/upload-artifact@v4
        with:
-          name: logs-and-reports
+          name: logs-and-reports-${{ matrix.runner }}
          path: |
            *.log
          if-no-files-found: ignore

README.rst

Lines changed: 52 additions & 21 deletions
@@ -28,8 +28,60 @@ Feature Support Status
 Installation
 ============
 
+Install from manylinux wheels
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Starting from ROCm 7.0, we provide manylinux wheels for Transformer Engine releases at `https://repo.radeon.com/rocm/manylinux`. For example, the wheels for ROCm 7.1.1 are at `https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/`. On that page you will find four files related to Transformer Engine:
+
+* transformer_engine_rocm-*-py3-none-manylinux_2_28_x86_64.whl - the wheel for the common library; it should not be installed by itself.
+* transformer_engine-*-py3-none-any.whl - the wheel for the common TE Python package.
+* transformer_engine_jax-*.tar.gz - the source tarball for the JAX extension.
+* transformer_engine_torch-*.tar.gz - the source tarball for the PyTorch extension.
+
+Below are example commands to download and install the wheels. They install both the PyTorch and JAX extensions on a system where both frameworks are installed.
+
+.. code-block:: bash
+
+    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_rocm-2.2.0-py3-none-manylinux_2_28_x86_64.whl
+    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine-2.2.0-py3-none-any.whl
+    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_jax-2.2.0.tar.gz
+    wget https://repo.radeon.com/rocm/manylinux/rocm-rel-7.1.1/transformer_engine_torch-2.2.0.tar.gz
+
+    pip install ./transformer_engine* --no-build-isolation
+
+Install TE from source
+^^^^^^^^^^^^^^^^^^^^^^
+
 Execute the following commands to install ROCm Transformer Engine from source on AMD GPUs:
 
+.. code-block:: bash
+
+    # Clone the TE repo and its submodules
+    git clone --recursive https://github.com/ROCm/TransformerEngine.git
+
+    cd TransformerEngine
+    export NVTE_FRAMEWORK=pytorch,jax   # optional; only pytorch and jax are supported; if unset, installed frameworks are detected automatically
+    export NVTE_ROCM_ARCH=gfx942,gfx950 # gfx942 for MI300/MI325, gfx950 for MI350
+
+    # Build platform selection (optional)
+    # Note: useful when both ROCm and CUDA platforms are present in the Docker image
+    export NVTE_USE_ROCM=1              # 1 for ROCm, 0 for CUDA; if unset, the installed platform is detected automatically, prioritizing ROCm
+
+    pip install . --no-build-isolation
+
+It is also possible to build wheels for later installation with "pip wheel .", although those wheels will not be portable to systems with
+different libraries installed. If the build still fails with the "--no-build-isolation" flag, try installing setuptools<80.0.0.
+
+Note on Switching between Installation from Source and Installation from Wheels
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+Issues can occur when installing from source on a system that previously had a wheel installation, or vice versa. It is safest to uninstall TE before
+switching between installing from source and installing from wheels. Here is an example command:
+
+.. code-block:: bash
+
+    # The package name pattern may be transformer_engine or transformer-engine depending on the setuptools version
+    pip list | grep transformer.engine | xargs pip uninstall -y
+
 Known Issue with ROCm 6.4 PyTorch Release
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
@@ -57,27 +109,6 @@ Re-install PyTorch
     ./tools/amd_build/build_amd.py
     BUILD_TEST=0 python3 setup.py install
 
-Install TE
-^^^^^^^^^^^^^^^^^^
-
-.. code-block:: bash
-
-    # Clone TE repo and submodules
-    git clone --recursive https://github.com/ROCm/TransformerEngine.git
-
-    cd TransformerEngine
-    export NVTE_FRAMEWORK=pytorch,jax #optionally set framework, currently only support pytorch and jax; if not set will try to detect installed frameworks
-    export NVTE_ROCM_ARCH=gfx942 # CK fused attn only support MI200 and MI300 and fp8 features are only supported on MI300
-
-    # Build Platform Selection (optional)
-    # Note: Useful when both ROCm and CUDA platforms are present in the Docker
-    export NVTE_USE_ROCM=1 #Use 1 for ROCm, or set to 0 to use CUDA; If not set will try to detect installed platform, prioritizing ROCm
-
-    pip install --no-build-isolation .
-
-It is also possible to build wheels for later installation with "pip wheel ." although those wheels will not be portable to systems with
-different libraries installed. This build may also require "--no-build-isolation" and if the build still fails with this flag try installing setuptools<80.0.0
-
 Test
 ====
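Whichever installation path is used, a quick import check confirms that the common package and both framework extensions are in place. This is a minimal illustrative sketch, not part of the diff; it assumes the standard transformer_engine module layout and that both PyTorch and JAX are installed.

    # Minimal post-install sanity check (illustrative; not part of this commit).
    import transformer_engine          # common TE Python package
    import transformer_engine.pytorch  # PyTorch extension
    import transformer_engine.jax      # JAX extension

    print("Transformer Engine", transformer_engine.__version__)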
ci/ci_config.json

Lines changed: 1 addition & 1 deletion
@@ -1,6 +1,6 @@
 {
   "docker_images": {
-    "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.0.2_ubuntu22.04_py3.10_pytorch_release-2.7_9015dfdf_jax_v0.6.0_fa-v2.8.0",
+    "default": "registry-sc-harbor.amd.com/framework/te-ci:rocm-7.1.1_ubuntu22.04_py3.11_pytorch_release_2.8_63e525b2_jax_0.7.1_fa-2.8.0",
     "release_v1.13": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273",
     "release_v1.14": "compute-artifactory.amd.com:5000/rocm-plus-docker/framework/private/te-ci:rocm-6.4_0_ubuntu22_py310_torch25_jax0435qa_fa273"
   }

ci/pytorch.sh

Lines changed: 6 additions & 3 deletions
@@ -75,10 +75,13 @@ run_test_config(){
     NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py
     run_default_fa 1 test_parallel_cross_entropy.py
     NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py
-    NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
+    NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
     NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
-    NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_numerics.py
-    NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_fusible_ops.py
+    NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 test_numerics.py
+    NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 test_fusible_ops.py
+    NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_numerics.py
+    NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_fusible_ops.py
+    NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 triton_kernels/test_cast.py
 }
 
 run_test_config_mgpu(){
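The new configs are driven purely by environment variables, so a single case can be reproduced outside the CI wrapper. The sketch below is an assumption about how run_default_fa_lbl ultimately invokes pytest (its exact arguments are internal to the CI scripts); only the environment-variable names come from the diff.

    import os
    import subprocess

    # Rough local reproduction of the new "amax+triton" config: enable atomic
    # amax and the Triton cast-transpose path, then run the targeted test file.
    env = dict(os.environ)
    env["NVTE_USE_ATOMIC_AMAX"] = "1"
    env["NVTE_USE_CAST_TRANSPOSE_TRITON"] = "1"

    subprocess.run(
        ["python", "-m", "pytest", "-v", "test_numerics.py"],
        cwd="tests/pytorch",  # assumed location of the PyTorch test suite
        env=env,
        check=True,
    )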

tests/cpp/operator/test_cublaslt_gemm.cu

Lines changed: 16 additions & 4 deletions
@@ -29,8 +29,8 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {
 };
 
 std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
-  {2304, 768, 4096},
-};
+  {768, 3072, 4096},
+};
 
 // A, B, Bias, Gelu, D
 // Bias type choose as bf16 in use_fp8, D_type otherwise
@@ -228,6 +228,14 @@ void performTest(const TestParams& params) {
 
 #ifdef __HIP_PLATFORM_AMD__
 
+#if HIP_VERSION < 70200000
+  if (prop.major == 9 && prop.minor == 5 &&
+      params.transa && !params.transb &&
+      params.m == 2304 && params.k == 768 && params.n == 4096) {
+    GTEST_SKIP() << "Skip TN 2304x768x4096 on gfx950 for ROCm < 7.2";
+  }
+#endif
+
  // Enable FP8 GEMM + GELU fusion tests only on MI300 (gfx942) with ROCm > 7.0.
  // hipBLASLt currently supports this config only
  bool fp8_gelu_fusion_config = false;
@@ -287,11 +295,15 @@ void performTest(const TestParams& params) {
   }
   if (prop.major == 9 && prop.minor == 4) //gfx942 specific hipblasLt limitations
   {
+#if HIP_VERSION < 70100000
     if (params.use_gelu && dtype == DType::kBFloat16 && !params.transa) {
       GTEST_SKIP() << "BF16 GEMM with GELU is not supported in current config";
     }
-    if (has_fp8 && params.use_bias && dtype == DType::kFloat8E4M3 && !fp8_gelu_fusion_config) {
-      GTEST_SKIP() << "FP8 GEMM with bias and FP8 output is not supported in current config";
+#endif
+    if constexpr (std::is_same<D_Type, fp8>::value && std::is_same<Bias_Type, bf16>::value) {
+      if (params.use_bias && !fp8_gelu_fusion_config) {
+        GTEST_SKIP() << "GEMM with BF16 bias and FP8 output is not supported in current config";
+      }
     }
   }
 #endif

tests/cpp/test_common.h

Lines changed: 4 additions & 6 deletions
@@ -344,20 +344,18 @@ struct Numeric_Traits<fp8e4m3> {
   static constexpr double minSubnorm = 1.0 / static_cast<double>(1 << 9);   // std::pow(2.0, -9.0);
   static constexpr double maxSubnorm = 0.875 / static_cast<double>(1 << 6); // std::pow(2.0, -6.0);
   static constexpr double minNorm = 1.0 / static_cast<double>(1 << 6);      // std::pow(2.0, -6.0);
-#ifndef USE_ROCM
+#ifndef USE_ROCM
   static constexpr double maxNorm = 448.0;
-#elif HIP_VERSION >= 60300000
+#else
   static const double maxNorm;
-#else
-  static constexpr double maxNorm = 240.0;
-#endif //USE_ROCM
+#endif //USE_ROCM
   static const double artifInf;  // artificial Infinity
   static constexpr int maxBiasedExponentAsFP32 = 8 + FP32_EXPONENT_BIAS;
   static constexpr int maxUnbiasedExponentAsFP32 = 8;
   static constexpr int maxExpNorm = 1 << maxUnbiasedExponentAsFP32;
 };
 
-#if defined(USE_ROCM) && (HIP_VERSION >= 60300000)
+#ifdef USE_ROCM
 inline const double Numeric_Traits<fp8e4m3>::maxNorm = te_fp8_fnuz() ? 240.0 : 448.0;
 #endif
 
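The runtime selection between 240.0 and 448.0 reflects the two FP8 E4M3 encodings: the FNUZ variant used on gfx942 and the OCP variant used on gfx950 and CUDA devices. As an illustration only (not part of the diff, and assuming a PyTorch build recent enough to expose both FP8 dtypes), the two maxima can be checked directly:

    import torch

    # OCP E4M3 ("e4m3fn"): finite max 448.0 -- the non-ROCm / gfx950 value.
    print(torch.finfo(torch.float8_e4m3fn).max)    # 448.0
    # FNUZ E4M3 ("e4m3fnuz"): finite max 240.0 -- the gfx942 value that
    # te_fp8_fnuz() selects at runtime in the test header.
    print(torch.finfo(torch.float8_e4m3fnuz).max)  # 240.0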
tests/pytorch/test_fused_optimizer.py

Lines changed: 7 additions & 0 deletions
@@ -1,3 +1,5 @@
+# This file was modified for portability to AMDGPU
+# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
 # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 #
 # See LICENSE for license information.
@@ -8,6 +10,7 @@
 import pytest
 import torch
 from torch import nn
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 from torch.testing._internal.common_device_type import largeTensorTest
 import transformer_engine.pytorch as te
 from transformer_engine.common.recipe import DelayedScaling
@@ -16,6 +19,7 @@
 from transformer_engine.pytorch.utils import is_bf16_compatible
 from transformer_engine.pytorch.fp8 import FP8GlobalStateManager
 from transformer_engine.pytorch.utils import gpu_autocast_ctx
+from transformer_engine.pytorch.utils import get_device_compute_capability
 
 # Check if FP8 is supported
 fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available()
@@ -370,6 +374,7 @@ def test_bf16_exp_avg(self):
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
     @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8)
     def test_fp8_exp_avg(self):
+        model_tol = 3e-2 if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) else None
         self.gen_precision_aware_test(
             use_fp8_params=False,
             param_dtype=torch.bfloat16,
@@ -380,6 +385,8 @@ def test_fp8_exp_avg(self):
             exp_avg_sq_dtype=torch.float32,
             master_rtol=1e-2,
             master_atol=1e-2,
+            model_rtol=model_tol,
+            model_atol=model_tol,
         )
 
     @pytest.mark.skipif(not is_bf16_compatible(), reason="bf16 if not supported")
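The same gfx950 gating idiom recurs across the PyTorch test changes in this commit: IS_HIP_EXTENSION distinguishes ROCm builds, and TE's get_device_compute_capability() reports (9, 5) for gfx950. A minimal sketch of the pattern; the helper name below is illustrative, not part of the diff.

    from torch.utils.cpp_extension import IS_HIP_EXTENSION
    from transformer_engine.pytorch.utils import get_device_compute_capability

    def on_gfx950() -> bool:
        # On CUDA builds IS_HIP_EXTENSION is False and the check short-circuits,
        # so the capability query only runs on ROCm.
        return IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5)

    # Loosen tolerances only where gfx950 numerics differ; None keeps defaults.
    model_tol = 3e-2 if on_gfx950() else None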

tests/pytorch/test_fusible_ops.py

Lines changed: 12 additions & 0 deletions
@@ -38,7 +38,9 @@
 )
 from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor, MXFP8Quantizer
 from transformer_engine.pytorch.utils import is_bf16_compatible
+from transformer_engine.pytorch.utils import get_device_compute_capability
 import transformer_engine_torch as tex
+from torch.utils.cpp_extension import IS_HIP_EXTENSION
 
 # Import utility functions
 from utils import dtype_tols, make_recipe, reset_rng_states
@@ -971,6 +973,16 @@ def test_basic_linear_quantized(
         quantized_grad_input: bool,
     ) -> None:
         """GEMM with FP8 inputs and outputs"""
+        if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+            if (
+                quantization
+                and quantization.startswith("fp8")
+                and quantized_compute
+                and (quantized_grad_input or quantized_output)
+            ):
+                pytest.skip(
+                    "hipBLASLt does not provide suitable algorithms on gfx950 for this config."
+                )
         if quantization is None:
             pytest.skip("Skipping case without quantization")
         self._test_basic_linear(

tests/pytorch/test_numerics.py

Lines changed: 22 additions & 4 deletions
@@ -768,6 +768,13 @@ def test_gpt_full_activation_recompute(
 ):
     if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
         pytest.skip("FP8 parameters are not supported in debug mode.")
+    if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+        if (dtype == torch.bfloat16
+            and not fp8
+            and not use_reentrant
+            and recipe.float8_per_tensor_scaling()
+        ):
+            pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.")
 
     config = model_configs[model]
     torch.compiler.reset()  # avoid cache size limit overflow
@@ -2829,10 +2836,21 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model):
         max_seqlen_kv=config.max_seqlen_kv,
     )
 
-    torch.testing.assert_close(
-        y_bshd,
-        y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
-    )
+    if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5):
+        tols_thd = dtype_tols(dtype)
+        # On gfx950 the THD results differ slightly,
+        # which lowers the precision of the final result
+        tols_thd["atol"] = 2e-3
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
+            **tols_thd,
+        )
+    else:
+        torch.testing.assert_close(
+            y_bshd,
+            y_thd.reshape(bs, config.max_seqlen_q, config.hidden_size).contiguous(),
+        )
 
 
 @pytest.mark.parametrize(
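The THD change keeps the dtype-derived relative tolerance and widens only the absolute tolerance. A small self-contained sketch of that comparison style, using stand-in tensors rather than the test's real outputs (1.6e-2 is torch.testing.assert_close's default bf16 rtol):

    import torch

    # Stand-ins for the BSHD output and the THD output reshaped back to BSHD.
    torch.manual_seed(0)
    y_bshd = torch.randn(2, 128, 256, dtype=torch.bfloat16)
    y_thd = (y_bshd.float() + 1e-3).to(torch.bfloat16)  # small, bounded offset

    # Keep the bf16 default rtol but widen atol, mirroring the gfx950 THD path.
    torch.testing.assert_close(y_thd, y_bshd, rtol=1.6e-2, atol=2e-3)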
