Commits
132 commits
1d903f5
Changed VERSION to 2.5.0.dev0
ptrendx May 17, 2025
2645eae
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 3 – tests (#1612)
pggPL May 19, 2025
7be4339
Fix README render for uploading package to PyPI (#1798)
ksivaman May 19, 2025
730fd11
Enhance recipe compatibility (#1724)
negvet May 19, 2025
3baaf3f
Fix split_overlap_ag `aggregate=True` chunk offset calculation (#1768)
guyueh1 May 20, 2025
201de5f
Use an empty torch tensor to indicate no fp8 information in extra_sta…
pstjohn May 20, 2025
3e50d53
[Pytorch] NVIDIA-DL-Framework-Inspect support – part 4 – documentatio…
pggPL May 20, 2025
d35afe1
[PyTorch] Add docstring for CP load balancing (#1802)
cyanguwa May 20, 2025
cd11e00
Add missing docs for C API (#1803)
ksivaman May 21, 2025
097afc0
fix model parallel encoder to be properly sharded params (#1794)
sudhakarsingh27 May 21, 2025
9c436d5
[PyTorch] Fix saved_tensors access in Ops Fuser (#1807)
pggPL May 22, 2025
0cd1cd8
[JAX] Fix incorrectly skipped test_quantize_dbias tests (#1808)
jberchtold-nvidia May 22, 2025
6262280
Remove `comm_gemm_overlap` doc (#1815)
ksivaman May 22, 2025
00328ac
Build support for cuda 13 (#1809)
ksivaman May 22, 2025
b17f3f4
[JAX] Make primitive names more granular for better disabling granula…
jberchtold-nvidia May 22, 2025
1669b3f
Add docs for missing FP8 recipes. (#1816)
ksivaman May 22, 2025
e4c051f
[PyTorch] Activation ops support fusing backward pass with quantize (…
timmoon10 May 22, 2025
fe9a786
Fix test.sh scripts to test pure-JAX implementations (#1805)
jberchtold-nvidia May 23, 2025
cd37379
Fix the failing test cases in the CI (#1806)
ptrendx May 23, 2025
9627b07
Updated README - Added conda installation (#1826)
sbhavani May 27, 2025
30e3081
Fix multi-framework runtime lib loading (#1825)
ksivaman May 28, 2025
4732ed7
[JAX] Update jax_scaled_masked_softmax to match TE kernel implementat…
jberchtold-nvidia May 28, 2025
355c4e4
[JAX] FP8 GEMM via dot_general + direct quant (#1819)
phu0ngng May 28, 2025
c9e8e30
[JAX] Removes unneccessary reshapes for FP8 GEMM (#1820)
phu0ngng May 29, 2025
41909dc
[PyTorch] Linear op avoids saving input tensor if weight grad is not …
timmoon10 May 29, 2025
4292653
Avoid memory allocations and deallocations when creating NVTETensor (…
ptrendx May 29, 2025
855fa65
[JAX] Support SWA in CP Ring Attn THD striped sharding (#1810)
huanghua1994 May 29, 2025
204add8
Avoid searching unnecessary dirs for shared libs (#1801)
timmoon10 May 30, 2025
d5d7833
Quantizer update when recipe was changed (#1814)
negvet May 30, 2025
c6a9e26
[PyTorch][Jax] Add warning for missing SOs if both frameworks are ins…
ksivaman May 31, 2025
62f5c9e
[JAX] Use 1x quantization + jax transpose for performance for tensor-…
jberchtold-nvidia Jun 2, 2025
c141711
Minor improvements to runtime error checks during library loading (#1…
ksivaman Jun 2, 2025
f3d77f6
[JAX] Fix NVTETensor leak in attention.cpp (#1841)
jberchtold-nvidia Jun 3, 2025
8b3ba9d
Bump cuDNN FE (#1842)
ksivaman Jun 3, 2025
75fe560
Update list of authorized CI users (#1840)
mk-61 Jun 3, 2025
151a0af
[PyTorch] Miscellaneous fixes for attention (#1780)
cyanguwa Jun 3, 2025
97e493f
Remove deprecated global option for debug build (#1848)
ksivaman Jun 4, 2025
12af02f
Fix `NVTE_FRAMEWORK=all` installation (#1850)
ksivaman Jun 5, 2025
f64d145
[JAX] Fix 1x quantize kernel availability check on hopper (#1845)
jberchtold-nvidia Jun 5, 2025
557f0cb
Use versioned flavor of get driver entrypoint function (#1835)
ptrendx Jun 5, 2025
6123d7e
[JAX] Fix OTYPE for FP8 GEMM (#1838)
phu0ngng Jun 5, 2025
9985b02
[PyTorch] FP8 Subchannel Recipe With FP8 Gather And Configurable Scal…
zhongbozhu Jun 6, 2025
7948779
[JAX] GroupedQuantizer and GroupedScaledTensor (#1666)
phu0ngng Jun 6, 2025
05f3b57
[Common] Missing CUDA driver deallocations in Userbuffers (#1812)
denera Jun 6, 2025
beffb29
[PyTorch] Get `skip_fp8_weight_update` only in CUDA Graph Capturing (…
yaox12 Jun 7, 2025
fab7157
Fix all framework build from PR 1666 (#1857)
ksivaman Jun 7, 2025
f519e6e
FP8 Param support for offloading (#1823)
sanandaraj5597 Jun 9, 2025
fc18520
Use public API instead of removed private function in `te_llama.py` (…
janekb04 Jun 9, 2025
ddcda1f
Manage dependencies and add missing `einops` req (#1859)
ksivaman Jun 9, 2025
031c6cf
Python 3.12+ support (#1862)
ksivaman Jun 10, 2025
faee0e8
Support Context Parallel for Multi Latent Attention (MLA) (#1729)
yuzhongw-nvidia Jun 10, 2025
aedd7e1
pyproject.toml (#1852)
ksivaman Jun 10, 2025
0efc7da
[PyTorch] Fix backward compatibility for checkpoint loading (#1868)
ksivaman Jun 12, 2025
c293d3a
[PyTorch] Fix typo in GrouppedLinear (#1867)
pggPL Jun 12, 2025
5d01ef2
[JAX] GroupedDense v.2 without dynamic shape (#1721)
phu0ngng Jun 12, 2025
4d4f1ed
Cpu reload double buffer (#1695)
sanandaraj5597 Jun 12, 2025
c3b7c2a
Revert "[JAX] GroupedDense v.2 without dynamic shape" (#1874)
phu0ngng Jun 12, 2025
c9d7f3f
[JAX] GroupedDense v.2 without dynamic shape (#1875)
phu0ngng Jun 12, 2025
40a30a5
[PyTorch] Support L2Normalization basic op -> use for qk_norm (#1864)
negvet Jun 12, 2025
227961e
[JAX] Distinguish the reasons why fp8 / mxfp8 is not supported in uni…
huanghua1994 Jun 12, 2025
ecaf3e2
Fixes for JIT-able grouped_gemm (#1872)
phu0ngng Jun 12, 2025
d90ced7
Add support for overlapping wgrad NCCL AG with dgrad GEMM (#1849)
djns99 Jun 13, 2025
8d4bdbc
Optimize `/ops/fuser.py` by moving computation from `forward` to `__i…
janekb04 Jun 13, 2025
655512c
[PyTorch] Inference mode disables initializing quantized weights with…
timmoon10 Jun 13, 2025
e963e4a
[PyTorch] Add support for FP8 current scaling in operation-based API …
timmoon10 Jun 13, 2025
7b94bd9
[common] Added support of FP4 data type (#1779)
Oleg-Goncharov Jun 13, 2025
71c76b6
Add support for head_dim > 128 (#1797)
cyanguwa Jun 13, 2025
1ddfa0c
[JAX] Add support for Fused Attn MLA head_dim_qk != head_dim_v (#1851)
KshitijLakhani Jun 13, 2025
a69692a
Changed VERSION to 2.6.0.dev0
ptrendx Jun 13, 2025
01a504c
[JAX] Grouped GEMM & Dense support MXFP8 and handle empty matrices (#…
huanghua1994 Jun 16, 2025
8ce49c0
[Pytorch] Bugfix in te fusion ce implementation (#1879)
BestJuly Jun 16, 2025
ba8c923
Fix test case that assumes char is signed (#1881)
timmoon10 Jun 16, 2025
ae572af
[JAX] Fixes for L0_jax_distributed_unittest (#1884)
phu0ngng Jun 17, 2025
3a298e6
[JAX] TensorUsage + FP8 GEMM with all layouts handling on BW (#1844)
phu0ngng Jun 18, 2025
766e3b7
[PyTorch] Use FP16 tols for distributed tests with TF32 compute (#1831)
timmoon10 Jun 19, 2025
7db72db
Fix cppunittest test.sh for editable installs (#1869)
jberchtold-nvidia Jun 25, 2025
c30e961
[PyTorch][MoE] Reduce CPU Overhead By Fuse Torch Empty Calls (#1793)
zhongbozhu Jun 26, 2025
23cf4ff
[PyTorch|common] Optimize unpadding kernel for FP8 (#1866)
xiaoxi-wangfj Jun 26, 2025
964c2ed
[PyTorch Debug] Fix the issue with PP (#1894)
pggPL Jun 26, 2025
1d1d323
[PyTorch Debug] Fixed the empty tensor bug in statistics computation …
pggPL Jun 26, 2025
0587ecf
Optimize reshaping tensors in the `te.ops.Sequential` implementation …
janekb04 Jun 26, 2025
5b16807
[JAX] Use keyword args for jit in_shardings and out_shardings (#1898)
jberchtold-nvidia Jun 26, 2025
cc0cb35
[PyTorch] Skip KV cache for sm89 and cuDNN < 9.12 (#1895)
cyanguwa Jun 26, 2025
9d173c9
Fix MLA CP Bugs (#1896)
yuzhongw-nvidia Jun 28, 2025
447de6d
Fix layernorm output shape in LayernormLinear (#1906)
guyueh1 Jul 1, 2025
21b780c
Enable use of internal tensors in Sequential (#1900)
janekb04 Jul 1, 2025
6f4310d
Added MCore FSDP support for TE (#1890)
sanandaraj5597 Jul 1, 2025
1ae1d22
[PyTorch Debug] Skip some of debug tests if FP8 is not available. (#1…
pggPL Jul 4, 2025
d26cc3a
Add test for `LayerNormMLP` implementation using `te.ops.Sequential` …
janekb04 Jul 8, 2025
9166d4d
Call `pre_(first_)forward` only when global state changes (#1917)
janekb04 Jul 8, 2025
9d031fb
[JAX BUILD] Fixes for JAX 0.7.0 (#1936)
phu0ngng Jul 8, 2025
2f25d12
[PyTorch] Fix setting `align_size` when FP8 is not initialized (#1926)
yaox12 Jul 9, 2025
637facc
[PyTorch] Tests for loading previously-generated checkpoints (#1899)
timmoon10 Jul 9, 2025
3c4dfff
[JAX] Fix grouped GEMM error on CUDA 12.9.1 & later (#1925)
huanghua1994 Jul 9, 2025
96ee717
[PyTorch][MoE] MXFP8 Support to Reduce CPU Overhead By Fuse Torch Emp…
zhongbozhu Jul 9, 2025
4c7095c
Fixed cpu overhead when doing DS cast (#1941)
sanandaraj5597 Jul 9, 2025
1dd8f62
[PyTorch debug] Run test_sanity with debug tools enabled. (#1908)
pggPL Jul 10, 2025
6489189
Optimize CUDA Graph memory, FP8 wrapper, and uneven PP support (#1234)
buptzyb Jul 10, 2025
62acae0
[PyTorch][MoE] Kernels fusions for the MoE router (#1883)
Autumn1998 Jul 10, 2025
31fc29a
[PyTorch] Make `MXFP8Tensor` unpickling function backward compatible …
timmoon10 Jul 11, 2025
0a7e9fe
[JAX] Capped HuggingFace datasets version for TE/JAX encoder examples…
denera Jul 11, 2025
11fecc4
[JAX] Update distributed LayerNormMLP test tolerance for L40 (#1901)
jberchtold-nvidia Jul 11, 2025
ac76d55
[JAX] Fixes for the grouped_gemm with MXFP8 (#1945)
phu0ngng Jul 11, 2025
37da2d3
Add backward fusions of dbias+quantize and dbias+dactivation+quantize…
janekb04 Jul 12, 2025
dc97cc9
[PyTorch] Optimize the performance of permute fusion kernels (#1927)
hxbai Jul 14, 2025
397c4be
[PyTorch] Fix bugs in router fusion (#1944)
Autumn1998 Jul 14, 2025
214e2a4
[JAX] GEMM custom op (#1855)
denera Jul 14, 2025
1c702b4
Run-time checks for CUDA and cuBLAS versions (#1938)
timmoon10 Jul 14, 2025
e7251f9
[JAX] Resolve test conflict in JAX helper tests (#1916)
emmanuel-ferdman Jul 15, 2025
6c52679
Bump up FA to 2.8.1 (#1949)
vcherepanov-nv Jul 16, 2025
c0c12e2
[JAX] Support Flax sharding constraints (#1933)
jberchtold-nvidia Jul 16, 2025
0a1499f
[Pytorch] Dynamo ONNX export support (#1497)
pggPL Jul 16, 2025
bda2993
Handle dtypes more carefully in multi-tensor Adam (#1888)
timmoon10 Jul 16, 2025
fa91ed7
mxfp8 (for all gemm layouts) is not supported on 120+ arch yet (#1939)
sudhakarsingh27 Jul 17, 2025
07afda9
[PyTorch] Add save_original_input in Linear/GroupedLinear to save mem…
hxbai Jul 17, 2025
ed75c2b
[JAX] Tighten Encoder Test tolerances (#1955)
phu0ngng Jul 17, 2025
5350f27
[JAX] Remove unneccessary MXFP8 scale_inv padding (#1954)
phu0ngng Jul 17, 2025
f8933bb
[Common] Optimize KV cache related kernels (#1914)
cyanguwa Jul 17, 2025
657c965
Update cudnn-frontend to 1.13.0 (#1960)
cyanguwa Jul 18, 2025
2d4644b
[JAX] Set `precision=HIGHEST` for the ref_grouped_gemm impl in the un…
phu0ngng Jul 18, 2025
86c5097
[Test] Enable cuDNN Norm tests in the CPP suite (#1957)
phu0ngng Jul 18, 2025
ca7407e
[JAX] Update tolerance of distributed layernorm MLP for FP8 (#1971)
jberchtold-nvidia Jul 19, 2025
b109ff3
[ROCm] merge NV upstream commit ca7407e onto ROCm TE commit 6bbd03c a…
wangye805 Nov 18, 2025
4d3ca4d
[ROCm] resolve the ifu conflicts in common dir
wangye805 Oct 23, 2025
5ce0afd
[ROCm] resolve the conflicts in jax extension
wangye805 Oct 30, 2025
c9c9126
[ROCm] resolve the conflicts in pytorch extension
wangye805 Oct 30, 2025
9730903
[ROCm] resolve the conflicts in setup/build/init
wangye805 Oct 30, 2025
51bdbb8
[ROCm] resolve the conflicts in cpp tests
wangye805 Nov 2, 2025
ba59f81
[ROCm] resolve pytorch pytest conflicts
wangye805 Nov 2, 2025
5842c24
[ROCm] resolve the conflicts in TE jax pytest
wangye805 Nov 6, 2025
c3a9517
[ROCm] fix the example conflict and address reviewer comments
wangye805 Nov 21, 2025
aaceb5a
[ROCm] merge dev to commit 653b5b4
wangye805 Nov 21, 2025
418 changes: 156 additions & 262 deletions transformer_engine/common/CMakeLists.txt

Large diffs are not rendered by default.

42 changes: 0 additions & 42 deletions transformer_engine/common/__init__.py
@@ -6,14 +6,6 @@

"""FW agnostic user-end APIs"""

<<<<<<< HEAD
import functools
import sys
import glob
import sysconfig
import subprocess
=======
>>>>>>> ca7407e
import ctypes
import functools
import glob
@@ -30,7 +22,6 @@

import transformer_engine


_logger = logging.getLogger(__name__)


@@ -119,40 +110,9 @@ def _get_shared_object_file(library: str) -> Path:
if so_path is not None:
return so_path

<<<<<<< HEAD
# Case 1: Typical user workflow: Both locations are the same, return any result.
if te_install_dir == site_packages_dir:
if so_path_in_install_dir is not None:
return so_path_in_install_dir
raise FileNotFoundError(f"Could not find shared object file for Transformer Engine {library} lib.")

# Case 2: ERR! Both locations are different but returned a valid result.
# NOTE: Unlike for source installations, pip does not wipe out artifacts from
# editable builds. In case developers are executing inside a TE directory via
# an inplace build, and then move to a regular build, the local shared object
# file will be incorrectly picked up without the following logic.
if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
raise RuntimeError(
f"Found multiple shared object files: {so_path_in_install_dir} and"
f" {so_path_in_default_dir}. Remove local shared objects installed"
f" here {so_path_in_install_dir} or change the working directory to"
"execute from outside TE."
)

# Case 3: Typical dev workflow: Editable install
if so_path_in_install_dir is not None:
return so_path_in_install_dir

# Case 4: Executing from inside a TE directory without an inplace build available.
if so_path_in_default_dir is not None:
return so_path_in_default_dir

raise FileNotFoundError(f"Could not find shared object file for Transformer Engine {library} lib.")
=======
raise FileNotFoundError(
f"Could not find shared object file for Transformer Engine {library} lib."
)
>>>>>>> ca7407e


@functools.lru_cache(maxsize=None)
@@ -178,7 +138,6 @@ def load_framework_extension(framework: str) -> None:
# If the framework extension pip package is installed, it means that TE is installed via
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching version.
'''
if _is_pip_package_installed(module_name):
assert _is_pip_package_installed(
"transformer_engine"
@@ -197,7 +156,6 @@ def load_framework_extension(framework: str) -> None:
f" v{version(f'transformer-engine-{te_cuda_vers}')}. Install transformer-engine using "
f"'pip3 install transformer-engine[{extra_dep_name}]==VERSION'"
)
'''

# If the core package is installed via PyPI, log if
# the framework extension is not found from PyPI.
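The HEAD-side block removed above documents four lookup cases for the core shared object (regular install, conflicting copies left over from a stale in-place build, editable install, and execution from inside a TE checkout). A minimal Python sketch of that decision logic, with the hypothetical arguments `so_path_in_install_dir` and `so_path_in_default_dir` standing in for the two search results the real function computes earlier:

```python
# Sketch only: mirrors the four cases described in the removed HEAD block above.
# All arguments are assumed to come from an earlier search step in the real code.
from pathlib import Path
from typing import Optional


def _resolve_shared_object(
    library: str,
    te_install_dir: Path,
    site_packages_dir: Path,
    so_path_in_install_dir: Optional[Path],
    so_path_in_default_dir: Optional[Path],
) -> Path:
    # Case 1: typical user install -- both locations are the same.
    if te_install_dir == site_packages_dir:
        if so_path_in_install_dir is not None:
            return so_path_in_install_dir
        raise FileNotFoundError(
            f"Could not find shared object file for Transformer Engine {library} lib."
        )
    # Case 2: conflicting copies -- a stale editable/inplace build would shadow
    # the regular install, so fail loudly instead of guessing.
    if so_path_in_install_dir is not None and so_path_in_default_dir is not None:
        raise RuntimeError(
            f"Found multiple shared object files: {so_path_in_install_dir} and"
            f" {so_path_in_default_dir}. Remove the local shared objects or run"
            " from outside the TE source directory."
        )
    # Case 3: typical developer workflow -- editable install.
    if so_path_in_install_dir is not None:
        return so_path_in_install_dir
    # Case 4: running from inside a TE checkout without an inplace build.
    if so_path_in_default_dir is not None:
        return so_path_in_default_dir
    raise FileNotFoundError(
        f"Could not find shared object file for Transformer Engine {library} lib."
    )
```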
30 changes: 9 additions & 21 deletions transformer_engine/common/common.h
@@ -11,12 +11,11 @@

#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>
<<<<<<< HEAD
#endif //#ifndef __HIP_PLATFORM_AMD__
=======
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#else
#define FP4_TYPE_SUPPORTED false
#endif //#ifndef __HIP_PLATFORM_AMD__

>>>>>>> ca7407e
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
@@ -301,19 +300,16 @@ using fp8e4m3 = __nv_fp8_e4m3;
using fp8e5m2 = __nv_fp8_e5m2;
#if CUDA_VERSION >= 12080
using fp8e8m0 = __nv_fp8_e8m0;
<<<<<<< HEAD
#endif // CUDA_VERSION >= 12080
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif //FP4_TYPE_SUPPORTED
#else
using bf16 = hip_bfloat16;
using fp8e4m3 = te_hip_fp8_e4m3;
using fp8e5m2 = te_hip_fp8_e5m2;
#endif //__HIP_PLATFORM_AMD__
=======
#endif
#if FP4_TYPE_SUPPORTED
using fp4e2m1 = __nv_fp4_e2m1;
#endif
>>>>>>> ca7407e

using e8m0_t = uint8_t;

namespace detail {
@@ -341,15 +337,11 @@ TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e4m3)
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e5m2)
#if CUDA_VERSION >= 12080
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp8_e8m0)
<<<<<<< HEAD
#endif // CUDA_VERSION >= 12080
#endif // #ifdef __HIP_PLATFORM_AMD__
=======
#endif
#if FP4_TYPE_SUPPORTED
TRANSFORMER_ENGINE_TYPE_NAME(__nv_fp4_e2m1)
#endif
>>>>>>> ca7407e
#endif // #ifdef __HIP_PLATFORM_AMD__
#undef TRANSFORMER_ENGINE_TYPE_NAME

template <typename T>
@@ -741,12 +733,8 @@ CUtensorMapDataType get_CUtensorMapDataType(DType dtype);
void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor,
const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY,
const uint32_t shmemX, const uint32_t stride_elems,
<<<<<<< HEAD
const uint32_t offset_elems, const size_t type_size);
#endif //#ifndef __HIP_PLATFORM_AMD__
=======
const uint32_t offset_elems, const size_t type_num_bits);
>>>>>>> ca7407e
#endif //#ifndef __HIP_PLATFORM_AMD__

bool is_supported_by_CC_100();

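The common.h hunks above replace the HEAD-only `#endif` with an `FP4_TYPE_SUPPORTED` switch. A minimal sketch of that guard pattern, under the assumption that `cuda_fp4.h` supplies `__nv_fp4_e2m1` on CUDA 12.8+ (illustrative only, not the verbatim resolved header):

```cpp
// Sketch of the FP4 guard from the common.h hunks: FP4 e2m1 support is
// compiled in only for CUDA toolkits >= 12.8 and is forced off for HIP builds.
#ifndef __HIP_PLATFORM_AMD__
#include <cudaTypedefs.h>  // brings in cuda.h and hence CUDA_VERSION
#define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
#else
#define FP4_TYPE_SUPPORTED false
#endif  // __HIP_PLATFORM_AMD__

#if FP4_TYPE_SUPPORTED
#include <cuda_fp4.h>            // assumed available with CUDA >= 12.8
using fp4e2m1 = __nv_fp4_e2m1;   // matches the alias added in common.h
#endif
```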