
Commit 28d08a7

Fix merge conflicts and review suggestions

Update copyright years. Tweak comments. Fix various complaints from @greptile-apps.

Signed-off-by: Tim Moon <[email protected]>

1 parent 583e948

8 files changed: +22 −18 lines changed

qa/L0_pytorch_debug_unittest/test.sh

Lines changed: 8 additions & 8 deletions
@@ -28,16 +28,16 @@ mkdir -p "$XML_LOG_DIR"

 pip install pytest==8.2.1 || error_exit "Failed to install pytest"

-pytest -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py"
-pytest -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py"
-pytest -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py"
-pytest -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py"
-NVTE_TORCH_COMPILE=0 pytest -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py"
-pytest -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py"
+NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py"

 # standard sanity and numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py"
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py"

 if [ "$RET" -ne 0 ]; then
   echo "Error in the following test cases:$FAILED_CASES"

transformer_engine/common/common.h

Lines changed: 4 additions & 1 deletion
@@ -133,7 +133,10 @@ struct Tensor {
   NVTEScalingMode scaling_mode;
   NVTETensor nvte_tensor;
-  /*! Whether scaling factors are in format expected by GEMM */
+  /*! \brief Whether scaling factors are in format expected by GEMM
+   *
+   *  Only meaningful for MXFP8 and NVFP4.
+   */
   bool with_gemm_swizzled_scales = false;

   /*! Map from NVTETensorParam to parameter sizes */

transformer_engine/common/include/transformer_engine/swizzle.h

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
  ************************************************************************/

 /*! \file cast.h
- *  \brief Functions to cast to/from FP8.
+ *  \brief Functions to convert scaling factors into format expected by GEMM.
  */

 #ifndef TRANSFORMER_ENGINE_SWIZZLE_H_

transformer_engine/common/transpose/transpose.h

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

transformer_engine/pytorch/csrc/common.h

Lines changed: 1 addition & 1 deletion
@@ -387,7 +387,7 @@ inline at::ScalarType GetATenDType(transformer_engine::DType t) {
     case transformer_engine::DType::kFloat8E5M2:
       return at::kFloat8_e5m2;
     case transformer_engine::DType::kFloat8E8M0:
-      return at::kByte;
+      return at::kByte;  // e8m0 dtype requires PyTorch 2.7.0+
     default:
       NVTE_ERROR("Invalid type (", static_cast<int>(t), ").");
   }
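
The new comment is worth a word of context: PyTorch only ships a native e8m0 dtype (torch.float8_e8m0fnu) starting with 2.7.0, which is why the binding maps kFloat8E8M0 to raw bytes. A minimal sketch of how a version-tolerant caller might pick the dtype; the helper name is hypothetical and not part of this commit:

    import torch

    def e8m0_scale_dtype() -> torch.dtype:
        # torch.float8_e8m0fnu exists only in PyTorch 2.7.0+;
        # older releases must view e8m0 scale factors as raw bytes.
        return getattr(torch, "float8_e8m0fnu", torch.uint8)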

transformer_engine/pytorch/csrc/extensions/swizzle.cpp

Lines changed: 1 addition & 1 deletion
@@ -1,5 +1,5 @@
 /*************************************************************************
- * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  *
  * See LICENSE for license information.
  ************************************************************************/

transformer_engine/pytorch/distributed.py

Lines changed: 5 additions & 4 deletions
@@ -1093,15 +1093,16 @@ def _start_all_gather_fp8_blockwise(
     out_shape[0] *= world_size

     # Check that quantizer is valid
-    if quantizer is not None and not isinstance(quantizer, Float8BlockQuantizer):
+    if quantizer is None:
+        raise ValueError("Quantizer is missing")
+    if not isinstance(quantizer, Float8BlockQuantizer):
         raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")

     # Fall back to high-precision all-gather if FP8 is not supported
-    if quantizer is None or not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
+    if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
         out = torch.empty(out_shape, dtype=dtype, device=device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
-        if quantizer is not None:
-            out = quantizer(out)
+        out = quantizer(out)
         return out, None

     # Quantize input tensor if needed
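
The design choice here is fail-fast: a missing quantizer is now treated as a caller bug rather than a fallback case, which lets the high-precision path drop its quantizer-is-not-None guards. A standalone sketch of the same validation pattern, with hypothetical names, not this function's real signature:

    from typing import Optional

    def validate_quantizer(quantizer: Optional[object], expected_cls: type) -> None:
        # Separate "missing" from "wrong type" so each failure gets a clear
        # error, and so downstream code may assume quantizer is non-None.
        if quantizer is None:
            raise ValueError("Quantizer is missing")
        if not isinstance(quantizer, expected_cls):
            raise ValueError(f"Got non-FP8 blockwise quantizer ({quantizer.__class__.__name__})")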

transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py

Lines changed: 1 addition & 1 deletion
@@ -269,7 +269,7 @@ def update_usage(
         self._columnwise_data = None
         self._columnwise_scale_inv = None

-    def get_usages(self) -> Tuple[bool, bool]:
+    def get_usages(self) -> Dict[str, bool]:
         """Get the usage of the tensor"""
         return {
             "rowwise": self._rowwise_data is not None,
