Changes from all commits
26 commits
02bc852
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 3 – tests (#1612)
pggPL May 19, 2025
74525d1
Fix README render for uploading package to PyPI (#1798)
ksivaman May 19, 2025
cea1152
Enhance recipe compatibility (#1724)
negvet May 19, 2025
610c393
Use an empty torch tensor to indicate no fp8 information in extra_sta…
pstjohn May 20, 2025
c5ea9eb
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 4 – documentatio…
pggPL May 20, 2025
aafa053
[PyTorch] Add docstring for CP load balancing (#1802)
cyanguwa May 20, 2025
90458e7
Add missing docs for C API (#1803)
ksivaman May 21, 2025
3a5ca57
Remove `comm_gemm_overlap` doc (#1815)
ksivaman May 22, 2025
9b80ea9
Add docs for missing FP8 recipes. (#1816)
ksivaman May 22, 2025
7558c44
Fix the failing test cases in the CI (#1806)
ptrendx May 23, 2025
d82f67b
Fix multi-framework runtime lib loading (#1825)
ksivaman May 28, 2025
b1d2539
Release v2.4_rocm
alextmagro Oct 6, 2025
0e1c8fe
readd HIP data generation
alextmagro Oct 7, 2025
758ed7e
Missing ; in test_common
alextmagro Oct 8, 2025
d1b8dba
[CI] Removed Jax jit workaround, replaced with XLA_FLAGS=--xla_gpu_en…
VeeraRajasekhar Oct 31, 2025
fa8615d
CI hotfix: IFU test update (#329)
Micky774 Oct 10, 2025
08bf8fc
Fix and add MXFP8 GEMM test failures (#326)
ipanfilo Oct 19, 2025
c6a2c65
Fix FFI import. Add distributed tests hang workaround (#347)
ipanfilo Oct 23, 2025
499d2d8
Make TE ROCm wheels building image directly from manylinux image (#340)
ipanfilo Oct 27, 2025
235b9b6
[CI] Hotfix test_gemm_autotune update (#353)
VeeraRajasekhar Oct 31, 2025
bcae459
MXFP8 test scale off by 1 fix (#338)
alextmagro Oct 31, 2025
34b1a34
CI: allow numpy 2.0 (#366)
ipanfilo Nov 7, 2025
736ab30
Relax tolerance to pass 29x29x17389NT GEMM on MI350 (#365)
ipanfilo Nov 8, 2025
baed0d1
Bring back aiter solib with aiter update (#327)
ipanfilo Oct 12, 2025
cc5b356
[ROCm] update AITER to support aiter shared lib for multi-gpu (PRs 11…
wangye805 Oct 22, 2025
08344fe
Use .info/version for ROCm version (#368)
ipanfilo Nov 12, 2025
2 changes: 0 additions & 2 deletions .gitignore
@@ -52,7 +52,5 @@ compile_commands.json
**/profiler_outputs/
**/times.csv
tensor_dumps/
aiter/
transformer_engine/build_info.txt
transformer_engine/common/util/hip_nvml.*
transformer_engine/aiter/
2 changes: 1 addition & 1 deletion 3rdparty/aiter
Submodule aiter updated 413 files
17 changes: 9 additions & 8 deletions README.rst
@@ -450,7 +450,7 @@ Installation
============

System Requirements
^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^

* **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere

@@ -468,10 +468,10 @@ System Requirements
* **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell)

Installation Methods
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^

Docker (Recommended)
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^
The quickest way to get started with Transformer Engine is by using Docker images on
`NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.

@@ -496,7 +496,7 @@ Where 25.04 (corresponding to April 2025 release) is the container version.
* NGC PyTorch 23.08+ containers include FlashAttention-2

pip Installation
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^

**Prerequisites for pip installation:**

@@ -534,7 +534,7 @@ Source Installation
`See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_

Environment Variables
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^
These environment variables can be set before installation to customize the build process:

* **CUDA_PATH**: Path to CUDA installation
@@ -545,7 +545,7 @@ These environment variables can be set before installation to customize the buil
* **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job

Compiling with FlashAttention
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment.

You can verify which FlashAttention version is being used by setting these environment variables:
@@ -557,8 +557,9 @@ You can verify which FlashAttention version is being used by setting these envir
It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.

.. troubleshooting-begin-marker-do-not-remove

Troubleshooting
^^^^^^^^^^^^^^^^^^^
^^^^^^^^^^^^^^^

**Common Issues and Solutions:**

@@ -692,7 +693,7 @@ Papers
Videos
======

* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`_
* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`__
* `Blackwell Numerics for AI | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72458/>`_
* `Building LLMs: Accelerating Pretraining of Foundational Models With FP8 Precision | GTC 2025 <https://www.nvidia.com/gtc/session-catalog/?regcode=no-ncid&ncid=no-ncid&tab.catalogallsessionstab=16566177511100015Kus&search=zoho#/session/1726152813607001vnYK>`_
* `From FP8 LLM Training to Inference: Language AI at Scale | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72799/>`_
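The README hunks above cover verifying the FlashAttention backend and limiting build parallelism. A rough shell sketch of both points (the NVTE_DEBUG/NVTE_DEBUG_LEVEL variables and the pip package spec are assumptions, since they sit in the collapsed part of the diff; only MAX_JOBS=1 appears in the visible text):

    # Assumed debug variables to log which attention backend is selected at runtime.
    NVTE_DEBUG=1 NVTE_DEBUG_LEVEL=1 python your_script.py

    # Cap parallel compile jobs so FlashAttention-2 compilation does not exhaust RAM.
    MAX_JOBS=1 pip install --no-build-isolation transformer_engine[pytorch]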
7 changes: 6 additions & 1 deletion build_tools/jax.py
@@ -21,7 +21,12 @@ def xla_path() -> str:
Throws FileNotFoundError if XLA source is not found."""

try:
from jax.extend import ffi
import jax
from packaging import version
if version.parse(jax.__version__) >= version.parse("0.5.0"):
from jax import ffi
else:
from jax.extend import ffi
except ImportError:
if os.getenv("XLA_HOME"):
xla_home = Path(os.getenv("XLA_HOME"))
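The hunk above gates the ffi import on the installed jax version, falling back to jax.extend for releases before 0.5.0. A quick shell check of which path a given environment would take (the 0.5.0 cutoff and both module paths are taken from the diff):

    python -c "import jax, packaging.version as v; print('jax.ffi' if v.parse(jax.__version__) >= v.parse('0.5.0') else 'jax.extend.ffi')"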
34 changes: 22 additions & 12 deletions build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86
@@ -2,22 +2,32 @@
#
# See LICENSE for license information.

# This Dockerfile is used to build TransformerEngine wheels for ROCm on x86_64 architecture.
# It is based on the manylinux_2_28_x86_64 based image with ROCm installed.
ARG BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64:non_existent_rocm_tag
# This Dockerfile is used to build TransformerEngine wheels for ROCm on x86_64 architecture
# on top of the manylinux_2_28_x86_64 base image.

# Build args:
# BASE_IMAGE - Base manylinux image to use. Default: quay.io/pypa/manylinux_2_28_x86_64
# ROCM_REPO_URL - ROCm repository URL. Default: https://repo.radeon.com/rocm/rhel8/latest/main/
# GPU_TARGETS - Semicolon separated list of target GPU architectures. Default: "gfx942;gfx950"
# TARGET_BRANCH - Target branch for TransformerEngine. Default: none (use git default)
# GPU_TARGETS and TARGET_BRANCH can be overridden when starting a container, via the NVTE_ROCM_ARCH and TARGET_BRANCH environment variables.

# Set base image
ARG BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64
FROM $BASE_IMAGE

# Setup the build_system repo
RUN echo -e "[build_system]\nname=ROCm\nbaseurl=https://repo.almalinux.org/build_system/8/x86_64/\nenabled=1\ngpgcheck=0" >/etc/yum.repos.d/build_system.repo
ARG ROCM_REPO_URL=https://repo.radeon.com/rocm/rhel8/latest/main/

# Add and enable repos
RUN dnf update -y || true
RUN dnf install -y epel-release elrepo-release
RUN dnf config-manager --set-enabled build_system powertools extras epel elrepo
# Set up ROCm repo
RUN echo -e "[rocm]\nname=ROCm\nbaseurl=${ROCM_REPO_URL}\nenabled=1\ngpgcheck=0" > /etc/yum.repos.d/rocm.repo

# Setup packages
RUN dnf install -y --disablerepo=epel rocm-dev hipblaslt hipblaslt-devel hipcub hipcub-devel
RUN dnf group install -y "Development Tools" && dnf install -y git cmake llvm-toolset gcc-toolset-12

#Uncomment the next line for the ROCm 6.4 cmake workaround: remove the newer incompatible cmake preinstalled on the base image
#RUN rm /usr/local/bin/cmake || true

# Setup dev packages
RUN dnf group install -y "Development Tools" && \
dnf install -y git cmake llvm-toolset hipblaslt hipblaslt-devel gcc-toolset-12
RUN dnf clean all
RUN rm -rf /var/cache/dnf/*

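The new comment block documents the image's build arguments. A hedged usage sketch (the image tag is illustrative; the argument values are the documented defaults):

    docker build -f build_tools/wheel_utils/Dockerfile.rocm.manylinux.x86 \
        --build-arg BASE_IMAGE=quay.io/pypa/manylinux_2_28_x86_64 \
        --build-arg ROCM_REPO_URL=https://repo.radeon.com/rocm/rhel8/latest/main/ \
        --build-arg GPU_TARGETS="gfx942;gfx950" \
        -t te-rocm-wheel-builder .
    # Per the comment in the hunk, GPU_TARGETS and TARGET_BRANCH can instead be
    # overridden when starting a container, via NVTE_ROCM_ARCH and TARGET_BRANCH.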
20 changes: 13 additions & 7 deletions build_tools/wheel_utils/build_wheels.sh
@@ -30,11 +30,17 @@ fi

ROCM_BUILD=`${PYBINDIR}python -c "import build_tools.utils as u; print(int(u.rocm_build()))"`

if [ "$LOCAL_TREE_BUILD" != "1" ]; then
if [ "$ROCM_BUILD" = "1" ]; then
git pull
fi
git checkout $TARGET_BRANCH
git submodule update --init --recursive
fi

if [ "$ROCM_BUILD" = "1" ]; then
git pull
${PYBINDIR}pip install setuptools wheel
fi
git checkout $TARGET_BRANCH
git submodule update --init --recursive

if $BUILD_METAPACKAGE ; then
cd /TransformerEngine
@@ -50,10 +56,10 @@ if $BUILD_COMMON ; then
WHL_BASE="transformer_engine-${VERSION}"
if [ "$ROCM_BUILD" = "1" ]; then
TE_CUDA_VERS="rocm"
${PYBINDIR}pip install ninja dataclasses
if [ -n "$PYBINDIR" ]; then
PATH="$PYBINDIR:$PATH" #hipify expects python in PATH
fi
#dataclasses, psutil are needed for AITER
${PYBINDIR}pip install ninja dataclasses psutil
#hipify expects python in PATH, also ninja may be installed to python bindir
test -n "$PYBINDIR" && PATH="$PYBINDIR:$PATH" || true
else
TE_CUDA_VERS="cu12"
PYBINDIR=/opt/python/cp38-cp38/bin/
4 changes: 2 additions & 2 deletions ci/core.sh
@@ -31,14 +31,14 @@ fi
check_test_filter "nongemm"
if [ $? -eq 0 ]; then
echo ===== Run non GEMM tests =====
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -E "OperatorTest/GEMMTestSuite"
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -E "GEMMTestSuite"
test $? -eq 0 || test_run_error "non-GEMM"
fi

check_test_filter "gemm"
if [ $? -eq 0 ]; then
echo ===== Run GEMM tests =====
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -R "OperatorTest/GEMMTestSuite"
ctest --test-dir build -j"$n_parallel_jobs" -V --output-on-failure -R "GEMMTestSuite"
test $? -eq 0 || test_run_error "GEMM"
fi

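The change above drops the OperatorTest/ prefix from the ctest regex. Run by hand, the same include/exclude split looks like this (same flags as the script, minus the job count):

    ctest --test-dir build -V --output-on-failure -R "GEMMTestSuite"   # GEMM suite only
    ctest --test-dir build -V --output-on-failure -E "GEMMTestSuite"   # everything except GEMM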
24 changes: 7 additions & 17 deletions ci/jax.sh
@@ -66,33 +66,23 @@ run_test_config() {
run_test_config_mgpu() {
echo ==== Run mGPU with Fused attention backend: $_fus_attn ====

_JAX_DISABLE_JIT_FLAG=${JAX_DISABLE_JIT:-0}
_ver=$(pip show jaxlib | grep Version)
case "$_ver" in
*0.4.35*)
# Workaround for distributed tests hang with JIT enabled
JAX_DISABLE_JIT=1 run 3 test_distributed_fused_attn.py -k 'not (test_context_parallel_allgather_attn[BALANCED or test_context_parallel_ring_attn)'
_JAX_DISABLE_JIT_FLAG=1

# Run tests that fail with JIT disabled
#run_lbl "allgather_balanced" 3 test_distributed_fused_attn.py -k 'test_context_parallel_allgather_attn[BALANCED'

# Workaround for distributed tests hang with xla_flag
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'

# Test ring attention with xla_flag --xla_experimental_ignore_channel_id only
# TODO: remove this flag after jax/xla update
XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
;;
*0.6.*)
# Workaround for distributed tests hang with JIT enabled
JAX_DISABLE_JIT=1 run 3 test_distributed_fused_attn.py -k 'not test_context_parallel_allgather_attn[BALANCED'
_JAX_DISABLE_JIT_FLAG=1
XLA_FLAGS="--xla_experimental_ignore_channel_id" run_lbl "parallel_ring" 3 test_distributed_fused_attn.py -k test_context_parallel_ring_attn
;;
*)
run 3 test_distributed_fused_attn.py
# Workaround for distributed tests hang with xla_flag
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run 3 test_distributed_fused_attn.py
;;
esac

run_default_fa 3 test_distributed_layernorm.py
JAX_DISABLE_JIT=$_JAX_DISABLE_JIT_FLAG run_default_fa 3 test_distributed_layernorm_mlp.py
XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run_default_fa 3 test_distributed_layernorm_mlp.py
run_default_fa 3 test_distributed_softmax.py

run_default_fa 3 test_sanity_import.py
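The rewritten case statement replaces the JAX_DISABLE_JIT workaround with XLA flags applied through the script's run/run_lbl helpers. Stripped of those helpers, the two workarounds amount to roughly the following (the pytest invocation is an assumption; the flags and -k filters come from the diff):

    # NCCL communicator-splitting workaround for most distributed fused-attention tests
    XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" \
        python -m pytest -v test_distributed_fused_attn.py -k 'not test_context_parallel_ring_attn'

    # Ring attention still needs the channel-id flag (to be removed after a jax/XLA update)
    XLA_FLAGS="--xla_experimental_ignore_channel_id" \
        python -m pytest -v test_distributed_fused_attn.py -k 'test_context_parallel_ring_attn'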
4 changes: 2 additions & 2 deletions ci/pytorch.sh
@@ -12,7 +12,7 @@ TEST_DIR=${TE_PATH}tests/pytorch
#: ${TEST_WORKERS:=4}

install_prerequisites() {
pip install 'numpy>=1.22.4,<2.0' pandas
pip install 'numpy>=1.22.4' pandas
rc=$?
if [ $rc -ne 0 ]; then
script_error "Failed to install test prerequisites"
@@ -58,7 +58,7 @@ run_test_config(){
run_default_fa 1 test_deferred_init.py
run_default_fa 1 test_float8tensor.py
run_default_fa 1 test_float8_current_scaling_exact.py
run_default_fa 1 test_cpu_offloading.py
test $_fus_attn = auto -o $_fus_attn = ck -o $_fus_attn = aotriton && NVTE_FLASH_ATTN=0 run 1 test_cpu_offloading.py
run_default_fa 1 test_fused_rope.py
run_default_fa 1 test_fusible_ops.py
run_default_fa 3 test_gemm_autotune.py
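The updated line only runs the CPU-offloading test for the auto, ck, and aotriton fused-attention configs, and disables flash attention while doing so. Outside the CI wrapper that is roughly (test path assumed):

    NVTE_FLASH_ATTN=0 python -m pytest -v tests/pytorch/test_cpu_offloading.py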
9 changes: 9 additions & 0 deletions docs/api/c/cast_transpose_noop.rst
@@ -0,0 +1,9 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

cast_transpose_noop.h
=====================

.. doxygenfile:: cast_transpose_noop.h
9 changes: 9 additions & 0 deletions docs/api/c/cudnn.rst
@@ -0,0 +1,9 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

cudnn.h
=======

.. doxygenfile:: cudnn.h
3 changes: 3 additions & 0 deletions docs/api/c/index.rst
@@ -14,10 +14,13 @@ directly from C/C++, without Python.

transformer_engine.h <transformer_engine>
activation.h <activation>
cast_transpose_noop.h <cast_transpose_noop>
cast.h <cast>
cudnn.h <cudnn>
fused_attn.h <fused_attn>
fused_rope.h <fused_rope>
gemm.h <gemm>
multi_tensor.h <multi_tensor>
normalization.h <normalization>
padding.h <padding>
permutation.h <permutation>
9 changes: 9 additions & 0 deletions docs/api/c/multi_tensor.rst
@@ -0,0 +1,9 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.

multi_tensor.h
==============

.. doxygenfile:: multi_tensor.h
4 changes: 4 additions & 0 deletions docs/api/common.rst
@@ -11,3 +11,7 @@ Common API
.. autoapiclass:: transformer_engine.common.recipe.DelayedScaling(margin=0, fp8_format=Format.HYBRID, amax_history_len=1024, amax_compute_algo="max", scaling_factor_compute_algo=None)

.. autoapiclass:: transformer_engine.common.recipe.MXFP8BlockScaling(fp8_format=Format.E4M3)

.. autoapiclass:: transformer_engine.common.recipe.Float8CurrentScaling(fp8_format=Format.HYBRID)

.. autoapiclass:: transformer_engine.common.recipe.Float8BlockScaling(fp8_format=Format.E4M3)
14 changes: 14 additions & 0 deletions docs/debug.rst
@@ -0,0 +1,14 @@
..
Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.

See LICENSE for license information.
Precision debug tools
==============================================

.. toctree::
:caption: Precision debug tools

debug/1_getting_started.rst
debug/2_config_file_structure.rst
debug/api
debug/4_distributed.rst