132 commits
1d903f5
Changed VERSION to 2.5.0.dev0
ptrendx May 17, 2025
2645eae
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 3 – tests (#1612)
pggPL May 19, 2025
7be4339
Fix README render for uploading package to PyPI (#1798)
ksivaman May 19, 2025
730fd11
Enhance recipe compatibility (#1724)
negvet May 19, 2025
3baaf3f
Fix split_overlap_ag `aggregate=True` chunk offset calculation (#1768)
guyueh1 May 20, 2025
201de5f
Use an empty torch tensor to indicate no fp8 information in extra_sta…
pstjohn May 20, 2025
3e50d53
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 4 – documentatio…
pggPL May 20, 2025
d35afe1
[PyTorch] Add docstring for CP load balancing (#1802)
cyanguwa May 20, 2025
cd11e00
Add missing docs for C API (#1803)
ksivaman May 21, 2025
097afc0
fix model parallel encoder to be properly sharded params (#1794)
sudhakarsingh27 May 21, 2025
9c436d5
[PyTorch] Fix saved_tensors access in Ops Fuser (#1807)
pggPL May 22, 2025
0cd1cd8
[JAX] Fix incorrectly skipped test_quantize_dbias tests (#1808)
jberchtold-nvidia May 22, 2025
6262280
Remove `comm_gemm_overlap` doc (#1815)
ksivaman May 22, 2025
00328ac
Build support for cuda 13 (#1809)
ksivaman May 22, 2025
b17f3f4
[JAX] Make primitive names more granular for better disabling granula…
jberchtold-nvidia May 22, 2025
1669b3f
Add docs for missing FP8 recipes. (#1816)
ksivaman May 22, 2025
e4c051f
[PyTorch] Activation ops support fusing backward pass with quantize (…
timmoon10 May 22, 2025
fe9a786
Fix test.sh scripts to test pure-JAX implementations (#1805)
jberchtold-nvidia May 23, 2025
cd37379
Fix the failing test cases in the CI (#1806)
ptrendx May 23, 2025
9627b07
Updated README - Added conda installation (#1826)
sbhavani May 27, 2025
30e3081
Fix multi-framework runtime lib loading (#1825)
ksivaman May 28, 2025
4732ed7
[JAX] Update jax_scaled_masked_softmax to match TE kernel implementat…
jberchtold-nvidia May 28, 2025
355c4e4
[JAX] FP8 GEMM via dot_general + direct quant (#1819)
phu0ngng May 28, 2025
c9e8e30
[JAX] Removes unneccessary reshapes for FP8 GEMM (#1820)
phu0ngng May 29, 2025
41909dc
[PyTorch] Linear op avoids saving input tensor if weight grad is not …
timmoon10 May 29, 2025
4292653
Avoid memory allocations and deallocations when creating NVTETensor (…
ptrendx May 29, 2025
855fa65
[JAX] Support SWA in CP Ring Attn THD striped sharding (#1810)
huanghua1994 May 29, 2025
204add8
Avoid searching unnecessary dirs for shared libs (#1801)
timmoon10 May 30, 2025
d5d7833
Quantizer update when recipe was changed (#1814)
negvet May 30, 2025
c6a9e26
[PyTorch][Jax] Add warning for missing SOs if both frameworks are ins…
ksivaman May 31, 2025
62f5c9e
[JAX] Use 1x quantization + jax transpose for performance for tensor-…
jberchtold-nvidia Jun 2, 2025
c141711
Minor improvements to runtime error checks during library loading (#1…
ksivaman Jun 2, 2025
f3d77f6
[JAX] Fix NVTETensor leak in attention.cpp (#1841)
jberchtold-nvidia Jun 3, 2025
8b3ba9d
Bump cuDNN FE (#1842)
ksivaman Jun 3, 2025
75fe560
Update list of authorized CI users (#1840)
mk-61 Jun 3, 2025
151a0af
[PyTorch] Miscellaneous fixes for attention (#1780)
cyanguwa Jun 3, 2025
97e493f
Remove deprecated global option for debug build (#1848)
ksivaman Jun 4, 2025
12af02f
Fix `NVTE_FRAMEWORK=all` installation (#1850)
ksivaman Jun 5, 2025
f64d145
[JAX] Fix 1x quantize kernel availability check on hopper (#1845)
jberchtold-nvidia Jun 5, 2025
557f0cb
Use versioned flavor of get driver entrypoint function (#1835)
ptrendx Jun 5, 2025
6123d7e
[JAX] Fix OTYPE for FP8 GEMM (#1838)
phu0ngng Jun 5, 2025
9985b02
[PyTorch] FP8 Subchannel Recipe With FP8 Gather And Configurable Scal…
zhongbozhu Jun 6, 2025
7948779
[JAX] GroupedQuantizer and GroupedScaledTensor (#1666)
phu0ngng Jun 6, 2025
05f3b57
[Common] Missing CUDA driver deallocations in Userbuffers (#1812)
denera Jun 6, 2025
beffb29
[PyTorch] Get `skip_fp8_weight_update` only in CUDA Graph Capturing (…
yaox12 Jun 7, 2025
fab7157
Fix all framework build from PR 1666 (#1857)
ksivaman Jun 7, 2025
f519e6e
FP8 Param support for offloading (#1823)
sanandaraj5597 Jun 9, 2025
fc18520
Use public API instead of removed private function in `te_llama.py` (…
janekb04 Jun 9, 2025
ddcda1f
Manage dependencies and add missing `einops` req (#1859)
ksivaman Jun 9, 2025
031c6cf
Python 3.12+ support (#1862)
ksivaman Jun 10, 2025
faee0e8
Support Context Parallel for Multi Latent Attention (MLA) (#1729)
yuzhongw-nvidia Jun 10, 2025
aedd7e1
pyproject.toml (#1852)
ksivaman Jun 10, 2025
0efc7da
[PyTorch] Fix backward compatibility for checkpoint loading (#1868)
ksivaman Jun 12, 2025
c293d3a
[PyTorch] Fix typo in GrouppedLinear (#1867)
pggPL Jun 12, 2025
5d01ef2
[JAX] GroupedDense v.2 without dynamic shape (#1721)
phu0ngng Jun 12, 2025
4d4f1ed
Cpu reload double buffer (#1695)
sanandaraj5597 Jun 12, 2025
c3b7c2a
Revert "[JAX] GroupedDense v.2 without dynamic shape" (#1874)
phu0ngng Jun 12, 2025
c9d7f3f
[JAX] GroupedDense v.2 without dynamic shape (#1875)
phu0ngng Jun 12, 2025
40a30a5
[PyTorch] Support L2Normalization basic op -> use for qk_norm (#1864)
negvet Jun 12, 2025
227961e
[JAX] Distinguish the reasons why fp8 / mxfp8 is not supported in uni…
huanghua1994 Jun 12, 2025
ecaf3e2
Fixes for JIT-able grouped_gemm (#1872)
phu0ngng Jun 12, 2025
d90ced7
Add support for overlapping wgrad NCCL AG with dgrad GEMM (#1849)
djns99 Jun 13, 2025
8d4bdbc
Optimize `/ops/fuser.py` by moving computation from `forward` to `__i…
janekb04 Jun 13, 2025
655512c
[PyTorch] Inference mode disables initializing quantized weights with…
timmoon10 Jun 13, 2025
e963e4a
[PyTorch] Add support for FP8 current scaling in operation-based API …
timmoon10 Jun 13, 2025
7b94bd9
[common] Added support of FP4 data type (#1779)
Oleg-Goncharov Jun 13, 2025
71c76b6
Add support for head_dim > 128 (#1797)
cyanguwa Jun 13, 2025
1ddfa0c
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (#1851)
KshitijLakhani Jun 13, 2025
a69692a
Changed VERSION to 2.6.0.dev0
ptrendx Jun 13, 2025
01a504c
[JAX] Grouped GEMM & Dense support MXFP8 and handle empty matrices (#…
huanghua1994 Jun 16, 2025
8ce49c0
[Pytorch] Bugfix in te fusion ce implementation (#1879)
BestJuly Jun 16, 2025
ba8c923
Fix test case that assumes char is signed (#1881)
timmoon10 Jun 16, 2025
ae572af
[JAX] Fixes for L0_jax_distributed_unittest (#1884)
phu0ngng Jun 17, 2025
3a298e6
[JAX] TensorUsage + FP8 GEMM with all layouts handling on BW (#1844)
phu0ngng Jun 18, 2025
766e3b7
[PyTorch] Use FP16 tols for distributed tests with TF32 compute (#1831)
timmoon10 Jun 19, 2025
7db72db
Fix cppunittest test.sh for editable installs (#1869)
jberchtold-nvidia Jun 25, 2025
c30e961
[PyTorch][MoE] Reduce CPU Overhead By Fuse Torch Empty Calls (#1793)
zhongbozhu Jun 26, 2025
23cf4ff
[PyTorch|common] Optimize unpadding kernel for FP8 (#1866)
xiaoxi-wangfj Jun 26, 2025
964c2ed
[PyTorch Debug] Fix the issue with PP (#1894)
pggPL Jun 26, 2025
1d1d323
[PyTorch Debug] Fixed the empty tensor bug in statistics computation …
pggPL Jun 26, 2025
0587ecf
Optimize reshaping tensors in the `te.ops.Sequential` implementation …
janekb04 Jun 26, 2025
5b16807
[JAX] Use keyword args for jit in_shardings and out_shardings (#1898)
jberchtold-nvidia Jun 26, 2025
cc0cb35
[PyTorch] Skip KV cache for sm89 and cuDNN < 9.12 (#1895)
cyanguwa Jun 26, 2025
9d173c9
Fix MLA CP Bugs (#1896)
yuzhongw-nvidia Jun 28, 2025
447de6d
Fix layernorm output shape in LayernormLinear (#1906)
guyueh1 Jul 1, 2025
21b780c
Enable use of internal tensors in Sequential (#1900)
janekb04 Jul 1, 2025
6f4310d
Added MCore FSDP support for TE (#1890)
sanandaraj5597 Jul 1, 2025
1ae1d22
[PyTorch Debug] Skip some of debug tests if FP8 is not available. (#1…
pggPL Jul 4, 2025
d26cc3a
Add test for `LayerNormMLP` implementation using `te.ops.Sequential` …
janekb04 Jul 8, 2025
9166d4d
Call `pre_(first_)forward` only when global state changes (#1917)
janekb04 Jul 8, 2025
9d031fb
[JAX BUILD] Fixes for JAX 0.7.0 (#1936)
phu0ngng Jul 8, 2025
2f25d12
[PyTorch] Fix setting `align_size` when FP8 is not initialized (#1926)
yaox12 Jul 9, 2025
637facc
[PyTorch] Tests for loading previously-generated checkpoints (#1899)
timmoon10 Jul 9, 2025
3c4dfff
[JAX] Fix grouped GEMM error on CUDA 12.9.1 & later (#1925)
huanghua1994 Jul 9, 2025
96ee717
[PyTorch][MoE] MXFP8 Support to Reduce CPU Overhead By Fuse Torch Emp…
zhongbozhu Jul 9, 2025
4c7095c
Fixed cpu overhead when doing DS cast (#1941)
sanandaraj5597 Jul 9, 2025
1dd8f62
[PyTorch debug] Run test_sanity with debug tools enabled. (#1908)
pggPL Jul 10, 2025
6489189
Optimize CUDA Graph memory, FP8 wrapper, and uneven PP support (#1234)
buptzyb Jul 10, 2025
62acae0
[PyTorch][MoE] Kernels fusions for the MoE router (#1883)
Autumn1998 Jul 10, 2025
31fc29a
[PyTorch] Make `MXFP8Tensor` unpickling function backward compatible …
timmoon10 Jul 11, 2025
0a7e9fe
[JAX] Capped HuggingFace datasets version for TE/JAX encoder examples…
denera Jul 11, 2025
11fecc4
[JAX] Update distributed LayerNormMLP test tolerance for L40 (#1901)
jberchtold-nvidia Jul 11, 2025
ac76d55
[JAX] Fixes for the grouped_gemm with MXFP8 (#1945)
phu0ngng Jul 11, 2025
37da2d3
Add backward fusions of dbias+quantize and dbias+dactivation+quantize…
janekb04 Jul 12, 2025
dc97cc9
[PyTorch] Optimize the performance of permute fusion kernels (#1927)
hxbai Jul 14, 2025
397c4be
[PyTorch] Fix bugs in router fusion (#1944)
Autumn1998 Jul 14, 2025
214e2a4
[JAX] GEMM custom op (#1855)
denera Jul 14, 2025
1c702b4
Run-time checks for CUDA and cuBLAS versions (#1938)
timmoon10 Jul 14, 2025
e7251f9
[JAX] Resolve test conflict in JAX helper tests (#1916)
emmanuel-ferdman Jul 15, 2025
6c52679
Bump up FA to 2.8.1 (#1949)
vcherepanov-nv Jul 16, 2025
c0c12e2
[JAX] Support Flax sharding constraints (#1933)
jberchtold-nvidia Jul 16, 2025
0a1499f
[Pytorch] Dynamo ONNX export support (#1497)
pggPL Jul 16, 2025
bda2993
Handle dtypes more carefully in multi-tensor Adam (#1888)
timmoon10 Jul 16, 2025
fa91ed7
mxfp8 (for all gemm layouts) is not supported on 120+ arch yet (#1939)
sudhakarsingh27 Jul 17, 2025
07afda9
[PyTorch] Add save_original_input in Linear/GroupedLinear to save mem…
hxbai Jul 17, 2025
ed75c2b
[JAX] Tighten Encoder Test tolerances (#1955)
phu0ngng Jul 17, 2025
5350f27
[JAX] Remove unneccessary MXFP8 scale_inv padding (#1954)
phu0ngng Jul 17, 2025
f8933bb
[Common] Optimize KV cache related kernels (#1914)
cyanguwa Jul 17, 2025
657c965
Update cudnn-frontend to 1.13.0 (#1960)
cyanguwa Jul 18, 2025
2d4644b
[JAX] Set `precision=HIGHEST` for the ref_grouped_gemm impl in the un…
phu0ngng Jul 18, 2025
86c5097
[Test] Enable cuDNN Norm tests in the CPP suite (#1957)
phu0ngng Jul 18, 2025
ca7407e
[JAX] Update tolerance of distributed layernorm MLP for FP8 (#1971)
jberchtold-nvidia Jul 19, 2025
b109ff3
[ROCm] merge NV upstream commit ca7407e onto ROCm TE commit 6bbd03c a…
wangye805 Nov 18, 2025
4d3ca4d
[ROCm] resolve the ifu conflicts in common dir
wangye805 Oct 23, 2025
5ce0afd
[ROCm] resolve the conflicts in jax extension
wangye805 Oct 30, 2025
c9c9126
[ROCm] resolve the conflicts in pytorch extension
wangye805 Oct 30, 2025
9730903
[ROCm] resolve the conflicts in setup/build/init
wangye805 Oct 30, 2025
51bdbb8
[ROCm] resolve the conflicts in cpp tests
wangye805 Nov 2, 2025
ba59f81
[ROCm] resolve pytorch pytest conflicts
wangye805 Nov 2, 2025
5842c24
[ROCm] resolve the conflicts in TE jax pytest
wangye805 Nov 6, 2025
c3a9517
[ROCm] fix the example conflict and address reviewer comments
wangye805 Nov 21, 2025
aaceb5a
[ROCm] merge dev to commit 653b5b4
wangye805 Nov 21, 2025
33 changes: 27 additions & 6 deletions .github/workflows/build.yml
@@ -18,8 +18,8 @@ jobs:
       - name: 'Dependencies'
         run: |
           apt-get update
-          apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
-          pip install cmake==3.21.0
+          apt-get install -y git python3.9 pip cudnn9-cuda-12
+          pip install cmake==3.21.0 pybind11[global] ninja
       - name: 'Checkout'
         uses: actions/checkout@v3
         with:
@@ -42,8 +42,8 @@ jobs:
       - name: 'Dependencies'
         run: |
           apt-get update
-          apt-get install -y git python3.9 pip ninja-build cudnn9-cuda-12
-          pip install cmake torch pydantic importlib-metadata>=1.0 packaging pybind11
+          apt-get install -y git python3.9 pip cudnn9-cuda-12
+          pip install cmake torch ninja pydantic importlib-metadata>=1.0 packaging pybind11 numpy einops onnxscript
       - name: 'Checkout'
         uses: actions/checkout@v3
         with:
@@ -54,7 +54,6 @@
           NVTE_FRAMEWORK: pytorch
           MAX_JOBS: 1
       - name: 'Sanity check'
-        if: false # Sanity import test requires Flash Attention
         run: python3 tests/pytorch/test_sanity_import.py
   jax:
     name: 'JAX'
@@ -63,6 +62,8 @@
       image: ghcr.io/nvidia/jax:jax
       options: --user root
     steps:
+      - name: 'Dependencies'
+        run: pip install pybind11[global]
       - name: 'Checkout'
         uses: actions/checkout@v3
         with:
@@ -73,4 +74,24 @@
           NVTE_FRAMEWORK: jax
           MAX_JOBS: 1
       - name: 'Sanity check'
-        run: python tests/jax/test_sanity_import.py
+        run: python3 tests/jax/test_sanity_import.py
+  all:
+    name: 'All'
+    runs-on: ubuntu-latest
+    container:
+      image: ghcr.io/nvidia/jax:jax
+      options: --user root
+    steps:
+      - name: 'Dependencies'
+        run: pip install torch pybind11[global] einops onnxscript
+      - name: 'Checkout'
+        uses: actions/checkout@v3
+        with:
+          submodules: recursive
+      - name: 'Build'
+        run: pip install --no-build-isolation . -v --no-deps
+        env:
+          NVTE_FRAMEWORK: all
+          MAX_JOBS: 1
+      - name: 'Sanity check'
+        run: python3 tests/pytorch/test_sanity_import.py && python3 tests/jax/test_sanity_import.py
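For local verification, the new "all" job can be approximated outside CI. The following is a minimal sketch assuming a machine matching the ghcr.io/nvidia/jax:jax container; the image name and package list come straight from the diff above, while the upstream repository URL is used purely for illustration:

# Build-time dependencies, mirroring the 'Dependencies' step of the new job
pip install torch pybind11[global] einops onnxscript
# Checkout with submodules, as actions/checkout does with 'submodules: recursive'
git clone --recursive https://github.com/NVIDIA/TransformerEngine.git
cd TransformerEngine
# --no-build-isolation reuses the packages installed above; --no-deps skips
# runtime dependency resolution, exactly as in the workflow's 'Build' step
NVTE_FRAMEWORK=all MAX_JOBS=1 pip install --no-build-isolation . -v --no-deps
# Sanity checks from the workflow's final step
python3 tests/pytorch/test_sanity_import.py && python3 tests/jax/test_sanity_import.py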
1 change: 1 addition & 0 deletions .github/workflows/trigger-ci.yml
@@ -53,6 +53,7 @@ jobs:
           || github.actor == 'lhb8125'
           || github.actor == 'kunlunl'
           || github.actor == 'pstjohn'
+          || github.actor == 'mk-61'
         )
     steps:
       - name: Check if comment is issued by authorized person
3 changes: 2 additions & 1 deletion .gitignore
@@ -49,8 +49,9 @@ downloads/
 .pytest_cache/
 compile_commands.json
 .nfs
-tensor_dumps/
 artifacts/
+**/profiler_outputs/
+**/times.csv
 tensor_dumps/
 transformer_engine/build_info.txt
 transformer_engine/common/util/hip_nvml.*
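As a quick check that the reworked ignore patterns behave as intended, git check-ignore can be queried against representative paths; the file names below are hypothetical examples, not files from this PR:

# Exit status 0 plus one output line per match confirms which pattern applies
git check-ignore -v outputs/profiler_outputs/trace.json times.csv tensor_dumps/dump0.bin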
33 changes: 23 additions & 10 deletions README.rst
@@ -449,7 +449,7 @@ Installation
 ============
 
 System Requirements
-^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^
 
 * **Hardware:** Blackwell, Hopper, Grace Hopper/Blackwell, Ada, Ampere
 
@@ -467,10 +467,10 @@ System Requirements
 * **Notes:** FP8 features require Compute Capability 8.9+ (Ada/Hopper/Blackwell)
 
 Installation Methods
-^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^
 
 Docker (Recommended)
-^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^
 The quickest way to get started with Transformer Engine is by using Docker images on
 `NVIDIA GPU Cloud (NGC) Catalog <https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch>`_.
 
@@ -495,7 +495,7 @@ Where 25.04 (corresponding to April 2025 release) is the container version.
 * NGC PyTorch 23.08+ containers include FlashAttention-2
 
 pip Installation
-^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^
 
 **Prerequisites for pip installation:**
 
@@ -519,21 +519,33 @@
 
 .. code-block:: bash
 
-    pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
+    pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
 
 When installing from GitHub, you can explicitly specify frameworks using the environment variable:
 
 .. code-block:: bash
 
-    NVTE_FRAMEWORK=pytorch,jax pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
+    NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable
+
+conda Installation
+^^^^^^^^^^^^^^^^^^
+
+To install the latest stable version with conda from conda-forge:
+
+.. code-block:: bash
+
+    # For PyTorch integration
+    conda install -c conda-forge transformer-engine-torch
+
+    # JAX integration (coming soon)
+
 Source Installation
 ^^^^^^^^^^^^^^^^^^^
 
 `See the installation guide <https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/installation.html#installation-from-source>`_
 
 Environment Variables
-^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^^
 These environment variables can be set before installation to customize the build process:
 
 * **CUDA_PATH**: Path to CUDA installation
@@ -544,7 +556,7 @@ These environment variables can be set before installation to customize the build process:
 * **NVTE_BUILD_THREADS_PER_JOB**: Control threads per build job
 
 Compiling with FlashAttention
-^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 Transformer Engine supports both FlashAttention-2 and FlashAttention-3 in PyTorch for improved performance. FlashAttention-3 was added in release v1.11 and is prioritized over FlashAttention-2 when both are present in the environment.
 
 You can verify which FlashAttention version is being used by setting these environment variables:
@@ -556,8 +568,9 @@ You can verify which FlashAttention version is being used by setting these environment variables:
 It is a known issue that FlashAttention-2 compilation is resource-intensive and requires a large amount of RAM (see `bug <https://github.com/Dao-AILab/flash-attention/issues/358>`_), which may lead to out of memory errors during the installation of Transformer Engine. Please try setting **MAX_JOBS=1** in the environment to circumvent the issue.
 
 .. troubleshooting-begin-marker-do-not-remove
+
 Troubleshooting
-^^^^^^^^^^^^^^^^^^^
+^^^^^^^^^^^^^^^
 
 **Common Issues and Solutions:**
 
@@ -691,7 +704,7 @@ Papers
 Videos
 ======
 
-* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`_
+* `Stable and Scalable FP8 Deep Learning Training on Blackwell | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc24-s62457/>`__
 * `Blackwell Numerics for AI | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72458/>`_
 * `Building LLMs: Accelerating Pretraining of Foundational Models With FP8 Precision | GTC 2025 <https://www.nvidia.com/gtc/session-catalog/?regcode=no-ncid&ncid=no-ncid&tab.catalogallsessionstab=16566177511100015Kus&search=zoho#/session/1726152813607001vnYK>`_
 * `From FP8 LLM Training to Inference: Language AI at Scale | GTC 2025 <https://www.nvidia.com/en-us/on-demand/session/gtc25-s72799/>`_
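One practical consequence of the README's switch to --no-build-isolation is that build requirements must already be installed when pip runs, since pip will no longer provision them in a temporary isolated environment. A minimal sketch of the full pip path follows; the prerequisite list is inferred from the build.yml changes above, not stated by the README itself:

# Install build prerequisites up front; with --no-build-isolation pip
# does not create an isolated build environment to supply them
pip install cmake pybind11[global] ninja
# Build and install from the stable branch for both supported frameworks
NVTE_FRAMEWORK=pytorch,jax pip install --no-build-isolation git+https://github.com/NVIDIA/TransformerEngine.git@stable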