CPU Optimizations for FP8 #2559
Changes from 28 commits
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
 }

 int nvte_is_non_tn_fp8_gemm_supported() {
-  int num_devices = transformer_engine::cuda::num_devices();
+  static int num_devices = transformer_engine::cuda::num_devices();

   static std::vector<int> cache(num_devices, -1);
   static std::vector<std::once_flag> flags(num_devices);

Suggested change:
   static int num_devices = transformer_engine::cuda::num_devices();
   static std::vector<int> cache(num_devices, -1);
   static std::vector<std::once_flag> flags(num_devices);
+  int device_id = transformer_engine::cuda::current_device();
+  NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");
Review comment (marked outdated):
Critical bug: potential out-of-bounds access when `device_id >= num_devices`

Making `num_devices` static causes it to be evaluated only once, at the first call. This creates a serious bug if the device ID is not within the range [0, num_devices):

- Problem: CUDA device IDs are not necessarily contiguous. For example, with `CUDA_VISIBLE_DEVICES=0,2` there are 2 visible devices, but the device IDs are 0 and 2.
- Bug scenario: if `num_devices` is initialized to 2 but later code calls `current_device()`, which returns 2, then `cache[device_id]` at line 975 accesses `cache[2]`, which is out of bounds (the cache size is 2, so the valid indices are 0-1).
- Impact: undefined behavior, most likely a segmentation fault or memory corruption.

Fix: either
- use `current_device()` as the cache key and bounds-check it (see the sketch below),
- map actual device IDs to cache indices, or
- remove the `static` keyword and re-evaluate `num_devices` on each call (sacrifices some of the optimization but ensures correctness).
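A minimal, self-contained sketch of the first option (bounds-check the current device before indexing the per-device cache), written against the plain CUDA runtime rather than TE's internal helpers; `is_feature_supported_cached` and the `prop.major >= 10` condition are placeholders, not the actual non-TN FP8 check:

```cpp
#include <cuda_runtime.h>
#include <cassert>
#include <mutex>
#include <vector>

// Sketch only: per-device cached capability query with a bounds check.
int is_feature_supported_cached() {
  static int num_devices = [] {
    int n = 0;
    cudaGetDeviceCount(&n);
    return n;
  }();
  static std::vector<int> cache(num_devices, -1);
  static std::vector<std::once_flag> flags(num_devices);

  int device_id = 0;
  cudaGetDevice(&device_id);
  assert(device_id >= 0 && device_id < num_devices);  // guard before indexing

  // The expensive query runs once per device; later calls hit the cache.
  std::call_once(flags[device_id], [&] {
    cudaDeviceProp prop{};
    cudaGetDeviceProperties(&prop, device_id);
    cache[device_id] = (prop.major >= 10) ? 1 : 0;  // placeholder condition
  });
  return cache[device_id];
}
```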
@@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
 PyTypeObject *NVFP4TensorPythonClass = nullptr;
 PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
 PyTypeObject *NVFP4QuantizerClass = nullptr;
+std::once_flag extension_init_flag;

 void init_float8_extension() {
-  if (Float8TensorPythonClass) return;
   auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
   Float8QuantizerClass =
       reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
@@ -54,7 +54,6 @@ void init_float8_extension() {
 }

 void init_mxfp8_extension() {
-  if (MXFP8TensorPythonClass) return;
   auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
   MXFP8QuantizerClass =
       reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
@@ -69,7 +68,6 @@ void init_mxfp8_extension() {
 }

 void init_float8blockwise_extension() {
-  if (Float8BlockwiseQTensorStoragePythonClass) return;
   auto fp8_module =
       py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
   auto fp8_base_module = py::module_::import(
@@ -90,7 +88,6 @@ void init_float8blockwise_extension() {
 }

 void init_nvfp4_extensions() {
-  if (NVFP4TensorPythonClass) return;
   auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
   NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
       PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
@@ -105,10 +102,12 @@ void init_nvfp4_extensions() {
 }

 void init_extension() {
-  init_float8_extension();
-  init_mxfp8_extension();
-  init_float8blockwise_extension();
-  init_nvfp4_extensions();
+  std::call_once(extension_init_flag, []() {
+    init_float8_extension();
+    init_mxfp8_extension();
+    init_float8blockwise_extension();
+    init_nvfp4_extensions();
+  });
 }

vthumbe1503 marked this conversation as resolved.
Comment on lines +105 to 111

Contributor: … using … if these functions are ever called directly (not through `init_extension()`).
 }  // namespace transformer_engine::pytorch
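For context, a minimal standalone sketch of the `std::call_once` idiom that the `init_extension()` change above adopts; `init_a` and `init_b` are trivial stand-ins, not the real extension initializers:

```cpp
#include <iostream>
#include <mutex>

std::once_flag extension_init_flag;

// Trivial stand-ins for the per-extension initializers.
void init_a() { std::cout << "init_a\n"; }
void init_b() { std::cout << "init_b\n"; }

void init_extension() {
  // The lambda runs exactly once, even under concurrent callers;
  // every later call returns immediately.
  std::call_once(extension_init_flag, []() {
    init_a();
    init_b();
  });
}

int main() {
  init_extension();
  init_extension();  // no-op: the flag is already set
}
```

Note that the flag only guards the wrapper: calling a stand-in (or, in the PR, a per-extension initializer) directly bypasses it, which is the concern raised in the contributor comment above now that the early-return guards are removed.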
Caching the `nvte_is_non_tn_fp8_gemm_supported()` result avoids redundant calls throughout the GEMM configuration for both the A and B matrices. Clean optimization with proper scoping.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
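A tiny illustration of the call-hoisting pattern this comment describes (query the capability once, reuse it when configuring both operands); `query_capability` and `configure_operand` are placeholder helpers, not TE APIs:

```cpp
#include <cstdio>

// Placeholder for nvte_is_non_tn_fp8_gemm_supported(): pretend it is costly.
static int query_capability() {
  std::puts("querying device capability...");
  return 1;
}

static void configure_operand(char name, int supported) {
  std::printf("operand %c: non-TN FP8 %s\n", name, supported ? "enabled" : "disabled");
}

int main() {
  const int supported = query_capability();  // single call, cached in a local
  configure_operand('A', supported);
  configure_operand('B', supported);
  return 0;
}
```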