Changes from 28 commits

Commits (59)
93ee022
add all the optimizations
vthumbe1503 Jan 5, 2026
06338bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2026
50de9cd
requires_grad optimization
vthumbe1503 Jan 6, 2026
5fee841
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 6, 2026
4c79ac7
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 6, 2026
62b88e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2026
99494d7
test if commenting out requires_grad works
vthumbe1503 Jan 7, 2026
b157f85
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 7, 2026
2a7b627
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 7, 2026
b61a6a8
fix minor bug
vthumbe1503 Jan 7, 2026
938651e
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 7, 2026
88dfdbd
fix ci
vthumbe1503 Jan 11, 2026
1526eea
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 11, 2026
5809dcc
missed a bug
vthumbe1503 Jan 11, 2026
b3bd748
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 11, 2026
30fecf2
Update transformer_engine/pytorch/csrc/quantizer.cpp
vthumbe1503 Jan 11, 2026
1b0d497
fix some bugs pointed to by copilot
vthumbe1503 Jan 11, 2026
138b7bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2026
eec1e86
linting error
vthumbe1503 Jan 11, 2026
8169d9c
fix the error
vthumbe1503 Jan 12, 2026
6fefaf2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2026
a5feaf9
fix the bug
vthumbe1503 Jan 13, 2026
285dbff
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 13, 2026
afb2f23
get rid of the change
vthumbe1503 Jan 13, 2026
3919cb8
fix the transpose shape bug
vthumbe1503 Jan 13, 2026
fd36424
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 13, 2026
4668133
minor linter fix
vthumbe1503 Jan 13, 2026
5a00652
fix lint
vthumbe1503 Jan 13, 2026
739bbad
fix linting error
vthumbe1503 Jan 16, 2026
e8042c1
address copilot review comment regarding error check when both data a…
vthumbe1503 Jan 16, 2026
1d323d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
da7fbf5
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 16, 2026
e2c7435
fix linting errors
vthumbe1503 Jan 16, 2026
f4e2492
fix merge conflict
vthumbe1503 Jan 16, 2026
beada36
missed a merge conflict
vthumbe1503 Jan 16, 2026
06a72a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
5d21db2
final optimizations
vthumbe1503 Jan 16, 2026
1dfd6fe
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 16, 2026
8c8dd20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
c1acd62
fix ci error
vthumbe1503 Jan 18, 2026
7f35b0b
fix merge conflixt
vthumbe1503 Jan 18, 2026
ca177ae
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 18, 2026
1538fd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2026
710b581
address review comment from greptile
vthumbe1503 Jan 18, 2026
8a57a75
fix merge conflixt
vthumbe1503 Jan 18, 2026
7e4f093
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2026
8604b69
address review comment + stride optimization
vthumbe1503 Jan 19, 2026
de44954
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 19, 2026
cc50745
address linter issue
vthumbe1503 Jan 19, 2026
f2e9a5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
0d75c3e
minor lint
vthumbe1503 Jan 20, 2026
3d9f673
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 20, 2026
53e8e4e
fix ci bug
vthumbe1503 Jan 20, 2026
c746abd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2026
9c922f5
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 20, 2026
88b782a
another optimization to do at::native::empty_cuda directly instead of…
vthumbe1503 Jan 20, 2026
5562cbe
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 20, 2026
14adf1a
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 23, 2026
1e28aa8
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 25, 2026
12 changes: 8 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -120,6 +120,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// Set conditions for MXFP8 and NVFP4 gemm execution.
const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling
if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) {
is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
}
Comment on lines +123 to +126

Caching the nvte_is_non_tn_fp8_gemm_supported() result avoids redundant calls throughout the GEMM configuration for both the A and B matrices. Clean optimization with proper scoping.



// Configure A matrix
if (is_tensor_scaling(A.scaling_mode)) {
@@ -129,7 +133,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
@@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +224,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
@@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
2 changes: 1 addition & 1 deletion transformer_engine/common/transformer_engine.cpp
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();

Making num_devices static has a subtle initialization order issue. The static initialization happens once when the function is first called, but the subsequent static vectors cache and flags depend on num_devices for their size.

If transformer_engine::cuda::num_devices() returns different values across multiple calls (which shouldn't happen in practice but isn't guaranteed by the API), the first call to this function will initialize num_devices, and subsequent calls will use that cached value. However, if the CUDA context changes or devices are added/removed (in rare scenarios), this could cause a mismatch.

Consider:

static int num_devices = transformer_engine::cuda::num_devices();

This is initialized once, but cache and flags vectors might need a different size if the device count somehow changes. While unlikely, this could cause out-of-bounds access.

A safer approach might be:

static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
  cache.resize(num_devices, -1);
  flags.resize(num_devices);
});

Or simply document that the device count must not change during the application's lifetime.


Critical: Potential out-of-bounds access with static num_devices

Making num_devices static creates a serious bug when GPUs are hot-plugged after the first call to this function. The cache and flags vectors are sized based on the initial device count, but device_id from current_device() could exceed num_devices if GPUs are added later.

This will cause out-of-bounds access on lines 968 and 975:

std::call_once(flags[device_id], ...);  // OOB if device_id >= num_devices
return cache[device_id];                 // OOB if device_id >= num_devices

Impact: Undefined behavior, potential crashes, memory corruption.

Fix: Query device count each time, or add bounds checking:

int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::vector<std::once_flag> flags;
static std::mutex resize_mutex;

int device_id = transformer_engine::cuda::current_device();
{
  std::lock_guard<std::mutex> lock(resize_mutex);
  if (device_id >= cache.size()) {
    cache.resize(device_id + 1, -1);
    flags.resize(device_id + 1);
  }
}
std::call_once(flags[device_id], [&]() { ... });
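
For reference, here is a self-contained variant of that bounds-checked cache, sketched under one assumption: query_device_support() is a hypothetical stand-in for the real capability check inside nvte_is_non_tn_fp8_gemm_supported(). It drops the per-device std::once_flag (which is neither copyable nor movable, so a std::vector<std::once_flag> cannot be resized) in favour of a single mutex:

#include <mutex>
#include <vector>

// Hypothetical stand-in for the real per-device capability check.
int query_device_support(int device_id);

int cached_non_tn_fp8_support(int device_id) {
  static std::mutex cache_mutex;
  static std::vector<int> cache;  // -1 marks "not queried yet"
  std::lock_guard<std::mutex> lock(cache_mutex);
  // Grow lazily so device_id can never index out of bounds.
  if (device_id >= static_cast<int>(cache.size())) {
    cache.resize(device_id + 1, -1);
  }
  if (cache[device_id] < 0) {
    cache[device_id] = query_device_support(device_id);
  }
  return cache[device_id];
}

After the first call per device the critical section is only a vector lookup, so the single mutex should not be a practical bottleneck.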


Potential race condition in static variable initialization

Making num_devices static improves performance by caching the device count, but this introduces a subtle race condition. While the cache and flags vectors are protected by std::call_once, num_devices itself is initialized on first function call without synchronization.

Scenario:

  1. Thread A enters function first time, starts calling transformer_engine::cuda::num_devices()
  2. Thread B enters function while A is still in num_devices(), sees uninitialized static
  3. Both threads may initialize num_devices, causing undefined behavior

However, this is likely safe in practice because:

  • The function is typically called during module initialization (single-threaded)
  • The value doesn't change after initialization
  • C++11 guarantees thread-safe initialization of function-local statics ("magic statics"), which mainstream compilers implement

But if you want the one-time initialization to be explicit, consider an immediately invoked lambda:

static int num_devices = []() {
  return transformer_engine::cuda::num_devices();
}();

Or initialize it within the std::call_once block if device_id doesn't need num_devices.

static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);

Potential edge case: num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device-enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.

While dynamic device changes are rare, consider adding a bounds check:

Suggested change
  static int num_devices = transformer_engine::cuda::num_devices();
  static std::vector<int> cache(num_devices, -1);
  static std::vector<std::once_flag> flags(num_devices);
+ int device_id = transformer_engine::cuda::current_device();
+ NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");


Critical bug: Potential out-of-bounds access when device_id >= num_devices

Making num_devices static causes it to be evaluated only once at first call. This creates a serious bug if the device ID is not within the range [0, num_devices):

  1. Problem: CUDA device IDs are not necessarily contiguous. For example, with CUDA_VISIBLE_DEVICES=0,2, you have 2 visible devices but device IDs are 0 and 2.
  2. Bug scenario: If num_devices is initialized to 2, but later code calls current_device() which returns 2, then cache[device_id] at line 975 will access cache[2] which is out of bounds (cache size is 2, valid indices are 0-1).
  3. Impact: This will cause undefined behavior, likely a segmentation fault or memory corruption.

Fix: Either:

  • Use current_device() as the cache key and bounds-check it
  • Map actual device IDs to cache indices
  • Remove the static keyword and re-evaluate num_devices on each call (sacrifices some of the optimization but ensures correctness)
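
A minimal sketch of the second option (mapping device IDs to cache entries), again assuming the same hypothetical query_device_support() helper in place of the real check:

#include <mutex>
#include <unordered_map>

// Hypothetical stand-in for the real per-device capability check.
int query_device_support(int device_id);

int cached_non_tn_fp8_support(int device_id) {
  // Keyed on the raw device ID, so sparse or unexpectedly large IDs
  // can never index out of bounds.
  static std::unordered_map<int, int> cache;
  static std::mutex cache_mutex;
  std::lock_guard<std::mutex> lock(cache_mutex);
  auto it = cache.find(device_id);
  if (it == cache.end()) {
    it = cache.emplace(device_id, query_device_support(device_id)).first;
  }
  return it->second;
}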

int device_id = transformer_engine::cuda::current_device();
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
The order of attributes checked is important to also
minimize overhead.
"""
if hasattr(tensor, "device"):
return tensor.device.index
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
@@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
return tensor.device.index
return torch.cuda.current_device()


15 changes: 7 additions & 8 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
std::once_flag extension_init_flag;

void init_float8_extension() {
if (Float8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
Float8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
@@ -54,7 +54,6 @@ void init_float8_extension() {
}

void init_mxfp8_extension() {
if (MXFP8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
MXFP8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
@@ -69,7 +68,6 @@ void init_mxfp8_extension() {
}

void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorStoragePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
@@ -90,7 +88,6 @@ void init_float8blockwise_extension() {
}

void init_nvfp4_extensions() {
if (NVFP4TensorPythonClass) return;
auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
@@ -105,10 +102,12 @@ void init_nvfp4_extensions() {
}

void init_extension() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
std::call_once(extension_init_flag, []() {
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
});
}
Comment on lines +105 to 111

Using std::call_once ensures thread-safe initialization. However, the individual init functions (init_float8_extension, etc.) previously had null-check guards that were removed.

If these functions are ever called directly (not through init_extension()), they'll re-import modules and reassign global pointers without protection. Verify that they're only called through init_extension().
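
If the per-extension initializers might still be called directly, one defensive pattern (a sketch, not the PR's code) is to give each initializer its own guard in addition to the top-level call_once:

#include <mutex>

// Sketch only: a per-initializer flag keeps a direct call to
// init_float8_extension() harmless, while init_extension() still funnels
// all four initializers through its own std::call_once.
static std::once_flag float8_extension_once;

void init_float8_extension() {
  std::call_once(float8_extension_once, [] {
    // import transformer_engine.pytorch.tensor.float8_tensor and capture
    // the Python type objects, as the existing function body does
  });
}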


} // namespace transformer_engine::pytorch