Skip to content

Commit 1de4b5e

Browse files
committed
Merge branch 'main' into tmoon/pre-swizzled-scales
2 parents 8b10300 + 5afbb0e commit 1de4b5e

File tree

34 files changed: 3651 lines added (+3651) and 254 lines removed (-254).

34 files changed: 3651 lines added (+3651) and 254 lines removed (-254).

build_tools/VERSION.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
2.11.0.dev0
1+
2.12.0.dev0

build_tools/pytorch.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
def install_requirements() -> List[str]:
1616
"""Install dependencies for TE/PyTorch extensions."""
17-
return ["torch>=2.1", "einops", "onnxscript", "onnx"]
17+
return ["torch>=2.1", "einops", "onnxscript", "onnx", "packaging", "pydantic"]
1818

1919

2020
def test_requirements() -> List[str]:

build_tools/utils.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,9 @@ def get_cuda_include_dirs() -> Tuple[str, str]:
241241

242242
cuda_root = Path(nvidia.__file__).parent
243243
return [
244-
cuda_root / "cuda_nvcc" / "include",
245-
cuda_root / "cublas" / "include",
246-
cuda_root / "cuda_runtime" / "include",
247-
cuda_root / "cudnn" / "include",
248-
cuda_root / "cuda_cccl" / "include",
249-
cuda_root / "nvtx" / "include",
250-
cuda_root / "cuda_nvrtc" / "include",
244+
subdir / "include"
245+
for subdir in cuda_root.iterdir()
246+
if subdir.is_dir() and (subdir / "include").is_dir()
251247
]
252248

253249

docs/examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb

Lines changed: 255 additions & 0 deletions
Large diffs are not rendered by default.

docs/index.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,4 @@ Transformer Engine documentation
5656
api/c/index
5757
debug
5858
examples/attention/attention.ipynb
59+
examples/attention/cp_ag_thd_dpa_jax_deep_dive.ipynb

qa/L0_jax_distributed_unittest/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ function test_fail() {
1616
RET=0
1717
FAILED_CASES=""
1818

19+
export NVTE_JAX_TEST_TIMING=1
20+
1921
: ${TE_PATH:=/opt/transformerengine}
2022
: ${XML_LOG_DIR:=/logs}
2123
mkdir -p "$XML_LOG_DIR"

qa/L0_jax_unittest/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ function test_fail() {
1818
RET=0
1919
FAILED_CASES=""
2020

21+
export NVTE_JAX_TEST_TIMING=1
22+
2123
pip3 install "nltk>=3.8.2" || error_exit "Failed to install nltk"
2224
pip3 install pytest==8.2.1 || error_exit "Failed to install pytest"
2325

qa/L0_pytorch_unittest/test.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED
3232
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py"
3333
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py"
3434
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_nvfp4.xml $TE_PATH/tests/pytorch/nvfp4 || test_fail "test_nvfp4"
35-
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py"
35+
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_quantized_tensor.xml $TE_PATH/tests/pytorch/test_quantized_tensor.py || test_fail "test_quantized_tensor.py"
3636
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py"
3737
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py"
3838
python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py"

qa/L1_jax_distributed_unittest/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ function test_fail() {
1111
RET=0
1212
FAILED_CASES=""
1313

14+
export NVTE_JAX_TEST_TIMING=1
15+
1416
: ${TE_PATH:=/opt/transformerengine}
1517
: ${XML_LOG_DIR:=/logs}
1618
mkdir -p "$XML_LOG_DIR"

qa/L2_jax_distributed_unittest/test.sh

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44

55
set -xe
66

7+
export NVTE_JAX_TEST_TIMING=1
8+
79
: ${TE_PATH:=/opt/transformerengine}
810
: ${XML_LOG_DIR:=/logs}
911
mkdir -p "$XML_LOG_DIR"

0 commit comments

Comments
 (0)