CPU Optimizations for FP8 #2559
```diff
@@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   ret.Atype = A.data.dtype;
   ret.A_scale_inv = A.scale_inv.dptr;
   ret.lda = is_A_transposed ? k : m;
-  if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
+  int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
     // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
     if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
       ret.A = A.columnwise_data.dptr;
@@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
     } else {
       NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
     }
-  } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
+  } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
     // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
     // data with the mirrored transpose-flag if we don't have row-wise data.
     NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
   ret.Btype = B.data.dtype;
   ret.B_scale_inv = B.scale_inv.dptr;
   ret.ldb = is_B_transposed ? n : k;
-  if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
+  int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
```
Suggested change:

```diff
-  int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  // Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration
```
[P0] Variable `is_nvte_non_tn_fp8_gemm_supported` is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.
Suggested change:

```diff
-  int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+  // Reuse is_nvte_non_tn_fp8_gemm_supported from line 132
```
[P0] Variable `is_nvte_non_tn_fp8_gemm_supported` is declared twice in the same function scope (first at line 132). This will cause a compilation error. Remove this redeclaration and reuse the variable from line 132.
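For reference, a minimal sketch of the restructuring both comments ask for. The stub types and the simplified parameter list below are placeholders, not the real signature from the file; only the "declare once at function scope, reuse for A and B" shape is the point.

```cpp
// Simplified stand-ins so the structure reads on its own; the real GemmParam and
// capability query live in the files touched by this PR.
struct GemmParam { /* fields elided */ };
int nvte_is_non_tn_fp8_gemm_supported();

GemmParam CanonicalizeGemmInput(bool is_A_transposed, bool is_B_transposed) {
  GemmParam ret;
  // Query the capability once per call, at function scope.
  const int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

  if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
    // A-operand Hopper path (TN-only FP8 GEMM) goes here.
  }

  if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
    // B-operand Hopper path reuses the same local; no second declaration.
  }
  return ret;
}
```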
```diff
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
 }
 
 int nvte_is_non_tn_fp8_gemm_supported() {
-  int num_devices = transformer_engine::cuda::num_devices();
+  static int num_devices = transformer_engine::cuda::num_devices();
 
   static std::vector<int> cache(num_devices, -1);
   static std::vector<std::once_flag> flags(num_devices);
```
Suggested change:

```diff
   static int num_devices = transformer_engine::cuda::num_devices();
   static std::vector<int> cache(num_devices, -1);
   static std::vector<std::once_flag> flags(num_devices);
+  int device_id = transformer_engine::cuda::current_device();
+  NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");
```
**Critical bug: Potential out-of-bounds access when `device_id >= num_devices`**

Making `num_devices` static causes it to be evaluated only once, at the first call. This creates a serious bug if the device ID is not within the range `[0, num_devices)`:

- **Problem:** CUDA device IDs are not necessarily contiguous. For example, with `CUDA_VISIBLE_DEVICES=0,2` you have 2 visible devices but device IDs are 0 and 2.
- **Bug scenario:** If `num_devices` is initialized to 2, but later code calls `current_device()` which returns 2, then `cache[device_id]` at line 975 will access `cache[2]`, which is out of bounds (cache size is 2, valid indices are 0-1).
- **Impact:** This will cause undefined behavior, likely a segmentation fault or memory corruption.

Fix: Either:

- Use `current_device()` as the cache key and bounds-check it
- Map actual device IDs to cache indices
- Remove the `static` keyword and re-evaluate `num_devices` on each call (sacrifices some of the optimization but ensures correctness)
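A minimal sketch of the first suggested option, reusing the `transformer_engine::cuda` helpers and `NVTE_CHECK` as they appear in this diff. `query_non_tn_fp8_support()` is a hypothetical stand-in for the real capability probe, which is not visible here.

```cpp
#include <mutex>   // std::once_flag, std::call_once
#include <vector>

// Hypothetical probe for whether a given device supports non-TN FP8 GEMMs.
int query_non_tn_fp8_support(int device_id);

int nvte_is_non_tn_fp8_gemm_supported() {
  // Evaluated once; the per-device caches below are sized to it for the process lifetime.
  static int num_devices = transformer_engine::cuda::num_devices();
  static std::vector<int> cache(num_devices, -1);
  static std::vector<std::once_flag> flags(num_devices);

  // Key the cache by the caller's current device and bounds-check before indexing.
  int device_id = transformer_engine::cuda::current_device();
  NVTE_CHECK(device_id >= 0 && device_id < num_devices,
             "Device ID exceeds cached device count");

  std::call_once(flags[device_id],
                 [&] { cache[device_id] = query_non_tn_fp8_support(device_id); });
  return cache[device_id];
}
```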
```diff
@@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
     The order of attributes checked is important to also
     minimize overhead.
     """
-    if hasattr(tensor, "device"):
-        return tensor.device.index
     if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
         return tensor._rowwise_data.device.index
     if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
@@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
         return tensor._data.device.index
     if hasattr(tensor, "_transpose") and tensor._transpose is not None:
         return tensor._transpose.device.index
+    if hasattr(tensor, "device"):
+        return tensor.device.index
     return torch.cuda.current_device()
```
Comment on lines 79 to 87 (Contributor):

The reordering optimizes for quantized tensors but creates a performance regression for regular tensors. Before, regular tensors returned from the `device` check immediately; consider restoring a fast path such as:

```python
if type(tensor).__name__ == 'Tensor':
    return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()
```
Comment on lines 79 to 88 (Contributor):

**Performance regression for regular tensors**

The reordering of attribute checks optimizes for quantized tensors but creates a performance regression for regular `torch.Tensor` inputs.

Before: regular tensors hit the `device` check first and returned immediately. After: they fall through the `_rowwise_data`, `_columnwise_data`, `_data`, and `_transpose` checks before reaching the `device` check.

Impact: every call with a plain tensor now pays for several failed attribute checks.

Concern: If regular tensors are passed to this function more often than quantized tensors, the reordering may cost more than it saves.

Recommendation: Profile both paths or add an early isinstance check:

```python
def get_tensor_device(tensor: torch.Tensor) -> int:
    # Fast path for regular tensors
    if type(tensor).__name__ == 'Tensor':
        return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()
    # Quantized tensor path
    if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
        return tensor._rowwise_data.device.index
    # ... etc
```
vthumbe1503 marked this conversation as resolved.

vthumbe1503 marked this conversation as resolved.
Comment on lines 79 to 89 (Contributor):

**Reordering device attribute checks changes behavior for regular tensors**

The original code checked the `device` attribute first, so plain tensors returned immediately; they now fall through the quantized-tensor attribute checks before that attribute is consulted.

Impact: extra attribute lookups on every call with a regular tensor.

Concern: This optimization seems to hurt performance for the common case (regular tensors) while only helping the less common case (quantized tensors).

Question: What is the relative frequency of this function being called with regular tensors vs quantized tensors? If regular tensors are more common, this change may degrade overall performance.
```diff
@@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
 PyTypeObject *NVFP4TensorPythonClass = nullptr;
 PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
 PyTypeObject *NVFP4QuantizerClass = nullptr;
 bool is_extension_initialized = false;
```
Suggested change:

```diff
 bool is_extension_initialized = false;
+std::once_flag extension_init_flag;
```
And in `init_extension()`:

```cpp
void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}
```
**Removed null-check guards without proper synchronization guarantee**

The original code had null checks (`if (Float8TensorPythonClass) return;`) to prevent redundant initialization. These were removed in favor of `std::call_once`, but the individual `init_*` functions no longer have guards.

Issue: If `init_extension()` is called multiple times from different threads, `std::call_once` ensures the lambda runs once. However, if `init_float8_extension()` is called directly (not through `init_extension()`), it will re-import modules and reassign global pointers without protection.

Verification needed: Confirm that these functions are NEVER called directly, only through `init_extension()`. If they can be called directly, the null checks should be restored:
```cpp
void init_float8_extension() {
  if (Float8TensorPythonClass) return;  // Keep this guard
  auto fp8_module = py::module_::import(...);
  ...
}
```
Using `std::call_once` ensures thread-safe initialization. However, the individual init functions (`init_float8_extension`, etc.) previously had null-check guards that were removed. If these functions are ever called directly (not through `init_extension()`), they'll re-import modules and reassign global pointers without protection. Verify they're only called through `init_extension()`.
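A small, self-contained sketch of the guarded variant discussed above, assuming the global-pointer layout shown in this diff; `setup_float8_types()` is a hypothetical stand-in for the module import and class lookups the real function performs.

```cpp
#include <Python.h>
#include <mutex>  // std::once_flag, std::call_once

PyTypeObject *Float8TensorPythonClass = nullptr;
std::once_flag extension_init_flag;

// Hypothetical stand-in for the real per-extension setup (module import, class lookup).
void setup_float8_types();

void init_float8_extension() {
  // Guard retained: a direct call cannot redo the import or reassign the global pointer.
  if (Float8TensorPythonClass != nullptr) return;
  setup_float8_types();
}

void init_extension() {
  // call_once keeps the combined initialization thread-safe and idempotent.
  std::call_once(extension_init_flag, [] {
    init_float8_extension();
    // init_mxfp8_extension(); etc., as in the suggestion above.
  });
}
```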
Variable `is_nvte_non_tn_fp8_gemm_supported` is redeclared in the same scope - it was already declared at line 224 for the B matrix configuration. This causes a compilation error or shadowing issue. The variable should only be declared once at the beginning of the function scope and reused for both A and B matrix configurations.