2 changes: 1 addition & 1 deletion docs/debug/1_getting_started.rst
@@ -15,7 +15,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
- log the statistics for each of the tensors in every matrix multiply (GEMM) operation,
- run selected GEMMs in higher precision,
- run current scaling - with one scaling factor per tensor - for particular GEMMs,
- test new precisions and integrate them with FP8 training,
- test new precisions and integrate them with quantized training (FP8, NVFP4, etc.),
- ... and many more.

There are 4 things one needs to do to use Transformer Engine debug features:
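For orientation, those setup steps amount to writing a YAML config that enables the desired features, initializing the debug API with that config, naming the Transformer Engine layers, and calling debug_api.step() once per training iteration. A minimal sketch follows; debug_api.step() appears in the tests in this PR, but the initialize()/end_debug() calls, their arguments, and the paths are assumptions based on the getting-started guide:

import nvdlfw_inspect.api as debug_api

# Assumed signature: point the API at a YAML config and at the directory
# containing Transformer Engine's debug features; both paths are placeholders.
debug_api.initialize(
    config_file="./config.yaml",
    feature_dirs=["/path/to/transformer_engine/debug/features"],
)

for _ in range(10):
    # ... run the usual forward/backward pass with Transformer Engine modules ...
    debug_api.step()  # advance the debug iteration counter once per training step

debug_api.end_debug()  # assumed teardown call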
7 changes: 5 additions & 2 deletions docs/debug/3_api_features.rst
@@ -8,7 +8,10 @@ Debug features

.. autoapiclass:: transformer_engine.debug.features.log_tensor_stats.LogTensorStats
.. autoapiclass:: transformer_engine.debug.features.log_fp8_tensor_stats.LogFp8TensorStats
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
.. autoapiclass:: transformer_engine.debug.features.log_nvfp4_tensor_stats.LogNvfp4TensorStats
.. autoapiclass:: transformer_engine.debug.features.disable_quantization_gemm.DisableQuantizationGEMM
.. autoapiclass:: transformer_engine.debug.features.disable_quantization_layer.DisableQuantizationLayer
.. autoapiclass:: transformer_engine.debug.features.per_tensor_scaling.PerTensorScaling
.. autoapiclass:: transformer_engine.debug.features.fake_quant.FakeQuant
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_gemm.DisableFP8GEMM
.. autoapiclass:: transformer_engine.debug.features.disable_fp8_layer.DisableFP8Layer
120 changes: 120 additions & 0 deletions tests/pytorch/debug/test_log.py
@@ -15,6 +15,7 @@
is_fp8_available,
is_mxfp8_available,
is_fp8_block_scaling_available,
is_nvfp4_available,
)
from transformer_engine.pytorch.quantization import RecipeState
from transformer_engine.debug.pytorch.debug_state import TEDebugState
@@ -29,6 +30,7 @@
fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available(
return_reason=True
)
nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True)

LOG_QUANTIZED_CONFIG_BASE = """
log:
@@ -363,6 +365,124 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs):
TEDebugState._reset()


# NVFP4 tests
LOG_NVFP4_CONFIG_BASE = """
log:
layers:
layer_name_regex_pattern: .*
enabled:
True
transformer_engine:
LogNvfp4TensorStats:
enabled: True
stats: [
{stats}
]
tensors: [activation, gradient, weight]
freq: 2
start_step: 0
end_step: 10
"""


def test_nvfp4_numeric(feature_dirs):
"""Test that NVFP4 underflows% and MSE stats are computed correctly with known values."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)

log_nvfp4_config = LOG_NVFP4_CONFIG_BASE.format(stats="underflows%, mse")

with debug_session(log_nvfp4_config, feature_dirs) as log_dir:
from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer
from transformer_engine.pytorch.quantization import RecipeState

recipe_state = RecipeState.create(
recipe.NVFP4BlockScaling(),
mode="forward",
num_quantizers=3,
)

# Create test tensor with known distribution
torch.manual_seed(42)
tensor = torch.randn(128, 128, dtype=torch.bfloat16).cuda()
# Add some small values that should underflow to zero in FP4
tensor[0, :16] = 0.0001

quantizer = recipe_state.make_quantizers()[0]
quantized_tensor = quantizer(tensor)

debug_api.transformer_engine.inspect_tensor(
layer_name="test_layer",
tensor_name="activation",
iteration=0,
tp_group=None,
tensor=tensor,
quantizer=quantizer,
rowwise_quantized_tensor=quantized_tensor,
columnwise_quantized_tensor=quantized_tensor,
)
debug_api.step()

dequantized_tensor = quantized_tensor.dequantize()
output = read_log(log_dir)

# Validate both stats are present
assert "nvfp4_underflows%" in output, "underflows% stat missing"
assert "nvfp4_mse" in output, "mse stat missing"

# Extract values and validate numerics
underflows_value = None
mse_value = None

for line in output.splitlines():
if "nvfp4_underflows%" in line and "value=" in line:
underflows_value = float(line.split("value=")[1].split()[0])
if "nvfp4_mse" in line and "value=" in line:
mse_value = float(line.split("value=")[1].split()[0])

# Compute expected underflows: non-zero elements that became zero after quantization
orig_nonzero_mask = tensor != 0
dequant_zero_mask = dequantized_tensor == 0
expected_underflows = (
(orig_nonzero_mask & dequant_zero_mask).sum().float() / tensor.numel() * 100
)

# Allow some tolerance
assert underflows_value == pytest.approx(expected_underflows.cpu().item(), abs=1e-4)

# Compute expected MSE
expected_mse = torch.nn.functional.mse_loss(
dequantized_tensor.float(), tensor.float(), reduction="mean"
)

assert mse_value == pytest.approx(expected_mse.cpu().item(), abs=1e-4)


def test_fp8_stats_allows_nvfp4_with_recipe_prefix(feature_dirs):
"""Test that LogFp8TensorStats allows recipe-prefixed stats with NVFP4 for what-if analysis."""
if not nvfp4_available:
pytest.skip(reason_for_no_nvfp4)

# Use recipe-prefixed stat with NVFP4 - should work (computes MXFP8 separately)
log_fp8_config = LOG_QUANTIZED_CONFIG_BASE.format(stats="mxfp8_mse")

with debug_session(log_fp8_config, feature_dirs) as log_dir:
model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()

# Should work - recipe-prefixed stats compute MXFP8 separately for comparison
for _ in range(2):
with te.autocast(recipe=recipe.NVFP4BlockScaling()):
output = model(inp)
loss = output.sum()
loss.backward()
debug_api.step()

output = read_log(log_dir)
# Should have logged MXFP8 MSE stat (what-if scenario)
assert "mxfp8_mse" in output


def test_log_grouped_gemm(feature_dirs):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
44 changes: 25 additions & 19 deletions transformer_engine/debug/features/disable_fp8_gemm.py
@@ -2,17 +2,28 @@
#
# See LICENSE for license information.

"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect"""
"""DisableFP8GEMM Feature support for nvidia-dlframework-inspect

from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper
DEPRECATED: This is a backward compatibility alias for DisableQuantizationGEMM.
New code should use DisableQuantizationGEMM instead, which works with all quantization formats.
"""

import warnings

from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_gemm import DisableQuantizationGEMM


@Registry.register_feature(namespace="transformer_engine")
class DisableFP8GEMM(TEConfigAPIMapper):
class DisableFP8GEMM(DisableQuantizationGEMM):
"""
GEMM operations are executed in higher precision, even when FP8 autocast is enabled.

.. deprecated::
Use :class:`DisableQuantizationGEMM` instead. This class is maintained for
backward compatibility only. DisableQuantizationGEMM works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.

Parameters
----------

@@ -32,22 +43,17 @@ class DisableFP8GEMM(TEConfigAPIMapper):
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8GEMM:
DisableFP8GEMM: # Deprecated: use DisableQuantizationGEMM
enabled: True
gemms: [dgrad, wgrad]
"""

@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and FP8 GEMM execution."""

for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behaviour in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
def __init__(self, *args, **kwargs):
warnings.warn(
"DisableFP8GEMM is deprecated. "
"Use DisableQuantizationGEMM instead, which works with all quantization "
"formats (FP8, NVFP4, etc.).",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
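
To migrate an existing config, only the feature name changes; the gemms list and layer selection carry over unchanged. A before/after sketch in the same inline-YAML style as the docstring example above (the top-level section name is made up):

# Old config: still accepted, but the class constructor now emits a DeprecationWarning.
OLD_CONFIG = """
disable_quantized_gemms:
  enabled: True
  layers:
    layer_types: [fc1]
  transformer_engine:
    DisableFP8GEMM:
      enabled: True
      gemms: [dgrad, wgrad]
"""

# New config: identical apart from the feature name.
NEW_CONFIG = OLD_CONFIG.replace("DisableFP8GEMM", "DisableQuantizationGEMM")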
63 changes: 28 additions & 35 deletions transformer_engine/debug/features/disable_fp8_layer.py
@@ -2,54 +2,47 @@
#
# See LICENSE for license information.

"""DisableFP8Layer Feature support for nvidia-dlframework-inspect"""
"""DisableFP8Layer Feature support for nvidia-dlframework-inspect

import nvdlfw_inspect.api as debug_api
from nvdlfw_inspect.registry import Registry, api_method
DEPRECATED: This is a backward compatibility alias for DisableQuantizationLayer.
New code should use DisableQuantizationLayer instead, which works with all quantization formats.
"""

import warnings

from nvdlfw_inspect.registry import Registry
from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer


@Registry.register_feature(namespace="transformer_engine")
class DisableFP8Layer:
class DisableFP8Layer(DisableQuantizationLayer):
Reviewer comment (Collaborator):
It may be worth raising a deprecation warning in the constructor or something. DisableFP8GEMM would also benefit from this.

"""
Disables all FP8 GEMMs in the layer.

.. deprecated::
Use :class:`DisableQuantizationLayer` instead. This class is maintained for
backward compatibility only. DisableQuantizationLayer works with all quantization
formats (FP8, NVFP4, etc.), not just FP8.

Example
-------
.. code-block:: yaml

example_disable_fp8_layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableFP8Layer: # Deprecated: use DisableQuantizationLayer
enabled: True
"""

@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for selecting between high-precision and FP8 GEMM execution."""
for key in config:
if key not in ["enabled", "gemm"]:
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')
# If FP8 training, disable FP8 for the selected layers if this feature is enabled in config.
debug_api.log_message("FP8 Disabled", layer_name)

# If this feature is invoked, then FP8 GEMM is disabled.
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1

def parse_config_and_api(self, config, **_kwargs):
"""Determines whether to run the API
DisableFP8Layer is the only feature provided by the Transformer Engine
which does not inherit from TEConfigAPIMapper - this mapper is primarly responsible for
parsing gemms and tensors fields from the config, which are not needed for this feature.

Explanation of the parse_config_and_api can be found in the
nvidia-dlframework-inspect documentation.
"""
return config["enabled"], None
def __init__(self, *args, **kwargs):
warnings.warn(
"DisableFP8Layer is deprecated. "
"Use DisableQuantizationLayer instead, which works with all quantization "
"formats (FP8, NVFP4, etc.).",
DeprecationWarning,
stacklevel=2,
)
super().__init__(*args, **kwargs)
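
Addressing the reviewer comment above, both legacy aliases now warn on construction. A quick check of that behavior, assuming the feature classes can be instantiated directly with no arguments (in practice they are normally constructed by the nvdlfw_inspect framework):

import warnings

from transformer_engine.debug.features.disable_fp8_gemm import DisableFP8GEMM
from transformer_engine.debug.features.disable_fp8_layer import DisableFP8Layer

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    DisableFP8GEMM()   # legacy alias of DisableQuantizationGEMM
    DisableFP8Layer()  # legacy alias of DisableQuantizationLayer

# Both constructions should have raised DeprecationWarning.
assert sum(issubclass(w.category, DeprecationWarning) for w in caught) == 2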
59 changes: 59 additions & 0 deletions transformer_engine/debug/features/disable_quantization_gemm.py
@@ -0,0 +1,59 @@
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""DisableQuantizationGEMM Feature support for nvidia-dlframework-inspect"""

from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper


@Registry.register_feature(namespace="transformer_engine")
class DisableQuantizationGEMM(TEConfigAPIMapper):
"""
Disables quantization for specific GEMM operations, forcing high-precision execution.

Works with any quantization format (FP8, NVFP4, etc.).

Parameters
----------

gemms: List[str]
list of gemms to disable quantization for

- fprop
- dgrad
- wgrad

Example
-------
.. code-block:: yaml

example_disable_quantization_gemm:
enabled: True
layers:
layer_types: [fc1]
transformer_engine:
DisableQuantizationGEMM:
enabled: True
gemms: [dgrad, wgrad]
"""

@api_method
def fp8_gemm_enabled(
self, config, layer_name: str, gemm: str, iteration: int
): # pylint: disable=unused-argument
"""API call responsible for choice between high-precision and quantized GEMM execution.

Note: Method name kept as 'fp8_gemm_enabled' for backward compatibility with the debug API,
but it applies to all quantization formats (FP8, NVFP4, etc.).
"""

for key in config:
if key != "gemm":
raise ValueError(f'[NVTORCH INSPECT ERROR] Unexpected key in config: "{key}".')

# If this feature is invoked, then quantized GEMM is disabled (returns to high precision).
# If not, then default behavior in TransformerEngineAPI
# is that fp8_gemm() API call returns True.
return False, iteration + 1
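
For completeness, a sketch of wiring the new feature into a training step, modeled on the tests in this PR. The debug_api.initialize call and its arguments, the config file handling, and the top-level section name are assumptions; te.Linear, te.autocast, recipe.NVFP4BlockScaling, and debug_api.step all appear in the diffs above:

import torch
import nvdlfw_inspect.api as debug_api
import transformer_engine.pytorch as te
from transformer_engine.common import recipe

# Hypothetical config: run dgrad/wgrad of every matching layer in high precision.
CONFIG = """
disable_quantized_gemms:
  enabled: True
  layers:
    layer_name_regex_pattern: .*
  transformer_engine:
    DisableQuantizationGEMM:
      enabled: True
      gemms: [dgrad, wgrad]
"""

with open("debug_config.yaml", "w") as f:
    f.write(CONFIG)

# Assumed initialization signature; see the getting-started guide.
debug_api.initialize(config_file="debug_config.yaml",
                     feature_dirs=["transformer_engine/debug/features"])

model = te.Linear(128, 128, params_dtype=torch.bfloat16)
inp = torch.randn(128, 128, dtype=torch.bfloat16).cuda()

with te.autocast(recipe=recipe.NVFP4BlockScaling()):
    out = model(inp)
out.sum().backward()  # fprop stays quantized; dgrad/wgrad fall back to high precision
debug_api.step()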