Changes from all commits (17 commits):
bf5b217  Changed VERSION to 2.6.0 (KshitijLakhani, Jul 20, 2025)
c7d0271  [PyTorch] Remove GH pinned deps (#1961) (ksivaman, Jul 21, 2025)
787acff  [PyTorch] Reset FP8 weight workspace if usages are invalid (#1972) (timmoon10, Jul 21, 2025)
9926245  Fix the condition error when checking fp8 attn in `get_attention_back… (yuzhongw-nvidia, Jul 21, 2025)
4b537aa  [Common] Skip cuDNN 9.10.0/9.10.1 due to bugs (#1937) (cyanguwa, Jul 21, 2025)
7ba6cd5  [PyTorch] Debug linear layer when saving original input and using deb… (timmoon10, Jul 22, 2025)
b97c2bf  [Common] Improved performance of mxfp8 cast kernels (#1628) (Oleg-Goncharov, Jul 22, 2025)
a593092  Fix the device for cuDNN/cuBLAS handles (#1974) (cyanguwa, Jul 23, 2025)
928dfa8  [JAX] Fix current scaling test_helper.py and enable test_helper.py in… (jberchtold-nvidia, Jul 23, 2025)
13f5796  [JAX] Helper to disable TE custom calls + disable GemmPrimitive for n… (phu0ngng, Jul 24, 2025)
e02e289  Fix runtime lib loading for cuDNN (#1989) (ksivaman, Jul 24, 2025)
21d7410  Fix cudnn versioning support in PyTorch DPA and Fused attn (#1991) (KshitijLakhani, Jul 24, 2025)
0f585e8  [JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parall… (denera, Jul 24, 2025)
5f1142e  [PyTorch] Optimize cudagraph static_grad_outputs reuse (#1992) (buptzyb, Jul 25, 2025)
c90a720  Fix the use-after-free bug in unfused normalization (#2002) (ptrendx, Jul 29, 2025)
966a4ac  Merge remote-tracking branch 'upstream/release_v2.6' into yewang12/if… (wangye805, Jan 3, 2026)
97556c6  [ROCm] Resolve conflicts (wangye805, Dec 17, 2025)
8 changes: 4 additions & 4 deletions benchmarks/attention/benchmark_attention.py
@@ -9,11 +9,11 @@
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
from tests.pytorch.utils import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention

pd.set_option("display.precision", 4)

@@ -197,7 +197,7 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
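Both benchmark scripts and the docs example further down pick up the relocated test helper: `_get_attention_backends` from `tests/pytorch/fused_attn/test_fused_attn.py` becomes `get_available_attention_backends` from `tests/pytorch/utils.py`, and it now returns three values, with the middle one discarded at every call site in this PR. A minimal sketch of the new call pattern, run from the repository root; the `qkv_layout` string and the meaning of the discarded middle element are assumptions, not taken from the diff:

import torch
from tests.pytorch.utils import ModelConfig, get_available_attention_backends

# Mirror the new "test_0" entry from the ROCm benchmark below: b=2, sq=512, h=16, d=64.
config = ModelConfig(2, 512, 16, 64)

available_backends, _, fused_attn_backends = get_available_attention_backends(
    config,
    qkv_dtype=torch.bfloat16,     # dtype keyword as used in the diff
    qkv_layout="bshd_bshd_bshd",  # assumed layout string; the hunk truncates the remaining kwargs
)
# The second return value is unpacked as `_` everywhere in this PR; treating it as the
# flash-attention backend selection is an assumption.
print(available_backends, fused_attn_backends)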
30 changes: 17 additions & 13 deletions benchmarks/attention/benchmark_attention_rocm.py
@@ -13,15 +13,19 @@
import transformer_engine
from transformer_engine_torch import NVTE_Fused_Attn_Backend

# Add test_fused_attn to the sys path
# Add paths tests/pytorch/ and tests/pytorch/attention to the sys path
tests_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn")
os.path.join(os.path.dirname(__file__), "../../tests")
)
sys.path.append(tests_path)
sys.path.append(tests_path + "/pytorch")
sys.path.append(tests_path + "/pytorch/attention")

from test_fused_attn import (
# Add tests/pytorch/utils.py path into sys path
from utils import (
ModelConfig,
_get_attention_backends,
get_available_attention_backends,
)
from test_attention import (
_run_dot_product_attention,
)

@@ -46,12 +50,12 @@
is_training = True

model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
"test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias")
# test: b, sq, h, d
"test_0": ModelConfig(2, 512, 16, 64), # short seq
"test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # longer seq, mask
"test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"), # bias
"test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=8, attn_mask_type="causal"), # GQA
"test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=16, attn_mask_type="causal_bottom_right")
}

# DataFrame indices and columns for results
@@ -303,7 +307,7 @@ def sanity_checks(
}

for model, cfg in model_configs.items():
avail, _, fused_bes = _get_attention_backends(
avail, _, fused_bes = get_available_attention_backends(
cfg,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
@@ -364,7 +368,7 @@ def main(args):
# Benchmarking starts..
for model in model_configs.keys():
config = model_configs[model]
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
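The ROCm benchmark also adopts the new `ModelConfig` constructor: the old positional form `ModelConfig(b, h, hg, d, sq, skv, p, mask, bias)` is replaced by `ModelConfig(b, sq, h, d, ...)` with GQA groups, mask type, and bias type passed as keywords. A side-by-side sketch based on the entries above; the assumption that dropout defaults to 0.0 and that omitted keywords keep the old "no_mask"/"no_bias" behavior is mine, not stated in the diff:

from tests.pytorch.utils import ModelConfig

# Old form (removed): "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias")
# New form (added), with the sequence length moved to the second position:
test_1 = ModelConfig(2, 2048, 16, 128, attn_mask_type="causal")

# Bias is now a keyword as well:
test_2 = ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias")

# GQA cases pass num_gqa_groups explicitly. The values differ from the old "hg" column
# (4 vs 8, 8 vs 16); whether the parameter's meaning changed or the cases were retuned
# is not stated in the diff.
test_3 = ModelConfig(2, 8192, 32, 128, num_gqa_groups=8, attn_mask_type="causal")
test_4 = ModelConfig(2, 8192, 128, 128, num_gqa_groups=16, attn_mask_type="causal_bottom_right")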
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
2.6.0.dev0
2.6.0
15 changes: 1 addition & 14 deletions build_tools/pytorch.py
@@ -27,20 +27,7 @@

def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
reqs = ["einops"]
if not rocm_build():
reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/[email protected]#egg=nvdlfw-inspect"
)
reqs.extend(
[
"torch>=2.1",
"onnx",
"onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
]
)
return reqs
return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]


def test_requirements() -> List[str]:
6 changes: 3 additions & 3 deletions ci/pytorch.sh
@@ -65,7 +65,7 @@ run_test_config(){
run_default_fa 1 test_recipe.py
run 1 test_sanity.py
run_default_fa 1 test_sanity_import.py
run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test
run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test
run_default_fa 1 triton_kernels/test_cast.py
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
@@ -88,8 +88,8 @@ run_test_config_mgpu(){
run_default_fa 2 distributed/test_numerics.py
run_default_fa 1 distributed/test_torch_fsdp2.py
run_default_fa 2 distributed/test_torch_fsdp2_fp8.py
run_default_fa_lbl "flash" 3 fused_attn/test_fused_attn_with_cp.py -k "with_flash"
run_default_fa_lbl "fused" 2 fused_attn/test_fused_attn_with_cp.py -k "with_fused"
run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash"
run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused"
}

run_benchmark() {
4 changes: 2 additions & 2 deletions docs/debug/1_getting_started.rst
@@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
There are 4 things one needs to do to use Transformer Engine debug features:

1. Create a configuration YAML file to configure the desired features.
2. Import, and initialize the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool, which is installed as the dependency of the Transformer Engine.
2. Import, initialize, and install the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool.
3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically.
4. Invoke ``debug_api.step()`` at the end of one forward-backward pass.

@@ -238,4 +238,4 @@ Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_
.. figure:: ./img/tensorboard.png
:align: center

Fig 2: TensorBoard with plotted stats.
Fig 2: TensorBoard with plotted stats.
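Together with the install_requirements() change in build_tools/pytorch.py above, this docs update reflects that Nvidia-DL-Framework-Inspect is no longer pulled in automatically and must be installed by the user. A small guard sketch; the import path `nvdlfw_inspect.api` and the pip spec (copied from the requirement removed above) are assumptions about an unchanged upstream package, not something this diff guarantees:

# Fail early with an actionable hint if the debug tooling is missing.
try:
    import nvdlfw_inspect.api as debug_api  # import name assumed from the TE debug docs
except ImportError as exc:
    raise ImportError(
        "Nvidia-DL-Framework-Inspect is no longer installed as a Transformer Engine"
        " dependency; install it manually, e.g.\n"
        "  pip install 'nvdlfw-inspect @"
        " git+https://github.com/NVIDIA/[email protected]#egg=nvdlfw-inspect'"
    ) from exc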
2 changes: 1 addition & 1 deletion (file name not captured)
@@ -5,7 +5,7 @@
import os
import torch
from typing import Tuple
from tests.pytorch.fused_attn.test_fused_attn import ModelConfig
from tests.pytorch.utils import ModelConfig
from transformer_engine.pytorch.attention import DotProductAttention

# Initialize RNG state
18 changes: 9 additions & 9 deletions docs/examples/attention/attention.ipynb
@@ -375,7 +375,7 @@
"\n",
"Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n",
"\n",
"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
"For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts."
]
},
{
@@ -394,10 +394,10 @@
"| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n",
"\n",
"Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n",
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)"
"- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n",
"- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)"
]
},
{
@@ -458,7 +458,7 @@
" </tr>\n",
"</table>\n",
"\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n",
"\n",
"<div class=\"alert alert-info\">\n",
"<b>Note</b>\n",
@@ -548,7 +548,7 @@
"id": "dda4a589",
"metadata": {},
"source": [
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n",
"Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n",
"\n",
"### 3.3 Attention Bias\n",
"\n",
@@ -594,7 +594,7 @@
"\n",
"The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n",
"\n",
"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)."
"More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)."
]
},
{
@@ -612,7 +612,7 @@
"\n",
"- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n",
"\n",
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
"Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`."
]
}
],
8 changes: 4 additions & 4 deletions docs/examples/attention/example_attention.py
@@ -9,11 +9,11 @@
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
from tests.pytorch.utils import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention

# data type
dtype = torch.bfloat16
@@ -90,7 +90,7 @@ def main():
models = ["test_0"]
for model in models:
config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
4 changes: 2 additions & 2 deletions qa/L0_jax_unittest/test.sh
@@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
: ${XML_LOG_DIR:=/logs}
mkdir -p "$XML_LOG_DIR"

python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*"

pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist"
@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"
NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls"

if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"
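This script (and the L2 JAX script further down) switches from the regex knob NVTE_JAX_CUSTOM_CALLS_RE="" to NVTE_JAX_CUSTOM_CALLS="false" when running the encoder example without TE custom calls, matching the "[JAX] Helper to disable TE custom calls" commit in this PR. A sketch of the same toggle from Python; the assumption that the variable is read when `transformer_engine.jax` is imported is mine:

import os

# Disable all TE custom calls so the encoder test exercises the non-custom-call path,
# mirroring NVTE_JAX_CUSTOM_CALLS="false" in the updated QA scripts.
os.environ["NVTE_JAX_CUSTOM_CALLS"] = "false"  # assumed to take effect only if set before import

import transformer_engine.jax  # noqa: E402,F401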
4 changes: 2 additions & 2 deletions qa/L0_pytorch_unittest/test.sh
@@ -45,8 +45,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py"
NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py"
NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py"
3 changes: 2 additions & 1 deletion qa/L1_pytorch_distributed_unittest/test.sh
@@ -23,12 +23,13 @@ mkdir -p "$XML_LOG_DIR"

pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"

python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py"
python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py"


2 changes: 1 addition & 1 deletion qa/L2_jax_unittest/test.sh
@@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
# Test without custom calls
export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops"
NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"
NVTE_JAX_CUSTOM_CALLS="false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py"

if [ $RET -ne 0 ]; then
echo "Error: some sub-tests failed: $FAILED_CASES"