[PyTorch Debug] NVFP4 debug stats support #2296
base: main
Conversation
/te-ci pytorch
```diff
 @Registry.register_feature(namespace="transformer_engine")
-class DisableFP8Layer:
+class DisableFP8Layer(DisableQuantizationLayer):
```
It may be worth raising a deprecation warning in the constructor or something. DisableFP8GEMM would also benefit from this.
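One possible shape for that warning is sketched below. The import path and the assumption that the feature base class can be subclassed with a plain `__init__` follow this PR's new file names; this is not the final implementation.

```python
import warnings

# Assumed import path, based on the new disable_quantization_layer.py added in this PR.
from transformer_engine.debug.features.disable_quantization_layer import DisableQuantizationLayer


class DisableFP8Layer(DisableQuantizationLayer):
    """Deprecated alias kept for backward compatibility with existing configs."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            "DisableFP8Layer is deprecated; use DisableQuantizationLayer instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        super().__init__(*args, **kwargs)
```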
```python
    return total_elements - first_zeros - second_zeros


def add_nvfp4_underflows_stats():
```
With RHT, it's possible that the NVFP4 data has fewer zeros than the high-precision data, so this stat would be negative. It's not possible yet (we currently only apply RHT to the NVFP4 column-wise data) and it's probably beyond the scope of this PR, but it's something to consider if we ever generalize.
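For context, the underflow stat compares zero counts before and after quantization, which is why transforms like RHT that change the number of zeros matter. Below is a minimal sketch of the nibble-unpacking idea behind `count_nonzero_nvfp4`, assuming two E2M1 values are packed per `uint8` byte and that only the 0b0000 encoding is treated as zero; the normalization of underflows% (here, by total element count) is also a guess, not the PR's exact formula.

```python
import torch


def count_nonzero_nvfp4_sketch(packed_fp4_data: torch.Tensor) -> int:
    """Count non-zero FP4 values in a uint8 tensor packing two E2M1 nibbles per byte."""
    low = packed_fp4_data & 0x0F          # first value of each packed pair
    high = (packed_fp4_data >> 4) & 0x0F  # second value of each packed pair
    total_elements = packed_fp4_data.numel() * 2
    first_zeros = int((low == 0).sum())
    second_zeros = int((high == 0).sum())
    return total_elements - first_zeros - second_zeros


def nvfp4_underflows_percent_sketch(packed_fp4_data: torch.Tensor, original: torch.Tensor) -> float:
    """underflows% ~ share of values that are non-zero in high precision but zero in NVFP4."""
    nonzero_hp = int((original != 0).sum())
    nonzero_fp4 = count_nonzero_nvfp4_sketch(packed_fp4_data)
    return 100.0 * (nonzero_hp - nonzero_fp4) / max(original.numel(), 1)
```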
Greptile Overview
Greptile Summary
This PR adds NVFP4 statistics support (underflows% and MSE) to the debug features and refactors FP8-specific naming to quantization-agnostic terminology. The implementation includes a new LogNvfp4TensorStats feature class, NVFP4-specific statistics computation functions, and comprehensive tests.
Major changes:
- New `LogNvfp4TensorStats` feature for logging NVFP4 tensor statistics (underflows%, mse)
- New quantization-agnostic features `DisableQuantizationGEMM` and `DisableQuantizationLayer`
- Deprecated `DisableFP8GEMM` and `DisableFP8Layer` with proper warnings and backward compatibility
- Refactored variable names from "fp8" to "quantization" throughout the debug infrastructure
- Added NVFP4-specific stat computation functions (`count_nonzero_nvfp4`, `add_nvfp4_underflows_stats`)
- Updated `LogFp8TensorStats` to detect NVFP4 quantizers and provide clear error messages
- Comprehensive test coverage for NVFP4 numerics and what-if analysis scenarios
Issues found:
- Minor: Error message inconsistency in `log_nvfp4_tensor_stats.py` line 169 (says "QuantizedTensor" but checks for "NVFP4TensorStorage")
Confidence Score: 4/5
- This PR is safe to merge with minor style improvements recommended
- The implementation is well-structured with proper error handling, comprehensive tests, and good backward compatibility. The NVFP4 statistics computation logic appears correct after thorough analysis of the nibble unpacking and zero-value detection. The refactoring from FP8-specific to quantization-agnostic naming is clean and consistent. Only one minor style issue was found (error message inconsistency). The deprecation warnings are properly implemented with clear migration paths.
- No files require special attention - all changes are well-tested and properly implemented
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| transformer_engine/debug/features/log_nvfp4_tensor_stats.py | 4/5 | New file adds NVFP4 tensor statistics logging. Minor error message inconsistency found. |
| transformer_engine/debug/features/utils/stats_computation.py | 3/5 | Adds NVFP4-specific stats computation. Function appears correct after thorough analysis. |
| transformer_engine/debug/features/log_fp8_tensor_stats.py | 4/5 | Updated to handle NVFP4 quantizer detection. Good defensive error handling added. |
| transformer_engine/debug/features/disable_quantization_gemm.py | 5/5 | New quantization-agnostic feature, clean implementation with backward compatibility. |
| transformer_engine/debug/features/disable_quantization_layer.py | 5/5 | New quantization-agnostic feature, clean implementation with proper naming. |
| tests/pytorch/debug/test_log.py | 5/5 | Comprehensive NVFP4 tests added with proper numeric validation. |
Sequence Diagram
```mermaid
sequenceDiagram
    participant User as Training Script
    participant API as debug_api
    participant Feature as LogNvfp4TensorStats
    participant Buffer as STATS_BUFFERS
    participant Compute as stats_computation
    User->>API: autocast(recipe=NVFP4BlockScaling())
    User->>User: forward/backward pass
    User->>API: debug_api.step()
    Note over API,Feature: For each layer with NVFP4 quantization
    API->>Feature: inspect_tensor_enabled()
    Feature->>Feature: check_if_stat_is_supported()
    Feature-->>API: return (run_current, next_iter)
    alt Feature enabled for this iteration
        API->>Feature: inspect_tensor()
        Feature->>Feature: verify NVFP4Quantizer & NVFP4TensorStorage
        Feature->>Feature: get_stat_with_prefix() (add nvfp4_ prefix)
        Feature->>Buffer: try_add_buffer()
        Feature->>Feature: update_aux_dict()
        Note over Feature: aux_dict = {"nvfp4": quantized_tensor, "original_tensor": tensor}
        Feature->>Buffer: feed(aux_dict)
        Buffer->>Compute: compute nvfp4_underflows%
        Compute->>Compute: count_nonzero_nvfp4(packed_fp4_data)
        Note over Compute: Unpack nibbles, count non-zeros
        Buffer->>Compute: compute nvfp4_mse
        Compute->>Compute: dequantize & MSE loss
        Buffer-->>Feature: stats collected
    end
    API->>Buffer: flush statistics to log
    Buffer-->>User: statistics logged to file
```
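An illustrative training-script view of this flow is sketched below. The `autocast(recipe=NVFP4BlockScaling())` call follows the diagram above, and the `debug_api` entry points follow the Transformer Engine precision-debug docs; the config file name, feature directory, and import paths are assumptions rather than something verified against this PR.

```python
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import NVFP4BlockScaling  # import path assumed
import nvdlfw_inspect.api as debug_api

# Point the debug API at a YAML config that enables LogNvfp4TensorStats
# (hypothetical file name and feature directory).
debug_api.initialize(
    config_file="nvfp4_stats_config.yaml",
    feature_dirs=["transformer_engine/debug/features"],
)

layer = te.Linear(1024, 1024).cuda()
inp = torch.randn(32, 1024, device="cuda", requires_grad=True)

for _ in range(10):
    with te.autocast(recipe=NVFP4BlockScaling()):  # signature taken from the diagram
        out = layer(inp)
    out.sum().backward()
    debug_api.step()  # flush nvfp4_underflows% / nvfp4_mse stats for this iteration
```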
```python
assert isinstance(
    quantized_tensor, NVFP4TensorStorage
), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a QuantizedTensor."
```
Error message inconsistency: the assertion message says "quantized_tensor must be a QuantizedTensor" but the type check is for NVFP4TensorStorage.
Suggested change:
```diff
 assert isinstance(
     quantized_tensor, NVFP4TensorStorage
-), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a QuantizedTensor."
+), "[NVTORCH INSPECT ERROR] LogNvfp4TensorStats quantized_tensor must be a NVFP4TensorStorage."
```
Description
This PR adds support for NVFP4 statistics: underflows and mse. I add them as a separate feature because we may want to add many NVFP4-specific features later.
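For mse, the NVFP4 tensor is dequantized and compared with the original high-precision tensor, matching the "dequantize & MSE loss" step in the diagram above. A rough sketch, not the exact implementation (the `dequantize()` method name on the quantized tensor is an assumption):

```python
import torch


def nvfp4_mse_sketch(quantized_tensor, original_tensor: torch.Tensor) -> float:
    """MSE between the dequantized NVFP4 tensor and the original high-precision tensor."""
    dequantized = quantized_tensor.dequantize()  # assumed API on the NVFP4 tensor storage
    return torch.nn.functional.mse_loss(
        dequantized.float(), original_tensor.float()
    ).item()
```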
Also, I renamed a few variables from "fp8"-like to "quantization"-like names. I cannot rename all of them (for example, "is_fp8_gemm_enabled" is an API call), so I left some of them.
Fixes # (issue)
Type of change
Checklist: