Commits (60)
93ee022
add all the optimizations
vthumbe1503 Jan 5, 2026
06338bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2026
50de9cd
requires_grad optimization
vthumbe1503 Jan 6, 2026
5fee841
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 6, 2026
4c79ac7
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 6, 2026
62b88e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2026
99494d7
test if commenting out requires_grad works
vthumbe1503 Jan 7, 2026
b157f85
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 7, 2026
2a7b627
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 7, 2026
b61a6a8
fix minor bug
vthumbe1503 Jan 7, 2026
938651e
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 7, 2026
88dfdbd
fix ci
vthumbe1503 Jan 11, 2026
1526eea
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 11, 2026
5809dcc
missed a bug
vthumbe1503 Jan 11, 2026
b3bd748
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 11, 2026
30fecf2
Update transformer_engine/pytorch/csrc/quantizer.cpp
vthumbe1503 Jan 11, 2026
1b0d497
fix some bugs pointed to by copilot
vthumbe1503 Jan 11, 2026
138b7bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2026
eec1e86
linting error
vthumbe1503 Jan 11, 2026
8169d9c
fix the error
vthumbe1503 Jan 12, 2026
6fefaf2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2026
a5feaf9
fix the bug
vthumbe1503 Jan 13, 2026
285dbff
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 13, 2026
afb2f23
get rid of the change
vthumbe1503 Jan 13, 2026
3919cb8
fix the transpose shape bug
vthumbe1503 Jan 13, 2026
fd36424
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 13, 2026
4668133
minor linter fix
vthumbe1503 Jan 13, 2026
5a00652
fix lint
vthumbe1503 Jan 13, 2026
739bbad
fix linting error
vthumbe1503 Jan 16, 2026
e8042c1
address copilot review comment regarding error check when both data a…
vthumbe1503 Jan 16, 2026
1d323d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
da7fbf5
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 16, 2026
e2c7435
fix linting errors
vthumbe1503 Jan 16, 2026
f4e2492
fix merge conflict
vthumbe1503 Jan 16, 2026
beada36
missed a merge conflict
vthumbe1503 Jan 16, 2026
06a72a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
5d21db2
final optimizations
vthumbe1503 Jan 16, 2026
1dfd6fe
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 16, 2026
8c8dd20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
c1acd62
fix ci error
vthumbe1503 Jan 18, 2026
7f35b0b
fix merge conflixt
vthumbe1503 Jan 18, 2026
ca177ae
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 18, 2026
1538fd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2026
710b581
address review comment from greptile
vthumbe1503 Jan 18, 2026
8a57a75
fix merge conflixt
vthumbe1503 Jan 18, 2026
7e4f093
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2026
8604b69
address review comment + stride optimization
vthumbe1503 Jan 19, 2026
de44954
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 19, 2026
cc50745
address linter issue
vthumbe1503 Jan 19, 2026
f2e9a5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
0d75c3e
minor lint
vthumbe1503 Jan 20, 2026
3d9f673
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 20, 2026
53e8e4e
fix ci bug
vthumbe1503 Jan 20, 2026
c746abd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2026
9c922f5
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 20, 2026
88b782a
another optimization to do at::native::empty_cuda directly instead of…
vthumbe1503 Jan 20, 2026
5562cbe
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 20, 2026
14adf1a
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 23, 2026
1e28aa8
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 25, 2026
c651d65
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 26, 2026
10 changes: 6 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
@@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Contributor:
Variable is_nvte_non_tn_fp8_gemm_supported is also declared at line 224 for the B matrix configuration, in the same function scope. Declaring it twice causes a compilation error or shadowing issue.

The variable should be declared once at the beginning of the function scope and reused for both the A and B matrix configurations.

Suggested change:
-int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+// Move this declaration before line 125 (Configure A matrix)
+int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
@@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
@@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
Contributor:
Variable is_nvte_non_tn_fp8_gemm_supported is redeclared in the same function scope (already declared at line 132). In C++, this is a compilation error. The variable should be declared once at the function scope level before both the A matrix and B matrix configuration blocks.

Suggested change:
-int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+// Move declaration to line 104 (after GemmParam ret;) to avoid redeclaration

Contributor:
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is already declared at line 132. Declaring it again here causes a compilation error. Remove this declaration and reuse the variable from line 132.

Suggested change:
-int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
+// Reuse is_nvte_non_tn_fp8_gemm_supported from line 132

Contributor:
[P0] Variable is_nvte_non_tn_fp8_gemm_supported is declared twice in the same function scope (first at line 132). This will cause a compilation error. Remove this redeclaration and reuse the variable from line 132.
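To make the consolidated fix concrete, here is a minimal runnable sketch; the call to nvte_is_non_tn_fp8_gemm_supported and the two configuration blocks mirror the diff, while the function body and stand-in return value are illustrative only:

#include <iostream>

int nvte_is_non_tn_fp8_gemm_supported() { return 0; }  // stand-in: Hopper-like device

void canonicalize_gemm_input(bool is_A_transposed, bool is_B_transposed) {
  // Declared exactly once, at function scope, before both blocks.
  const int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();

  // Configure A matrix.
  if (!is_non_tn_fp8_gemm_supported && !is_A_transposed) {
    std::cout << "A: fall back to column-wise (transposed) data\n";
  }

  // Configure B matrix: reuses the same variable, no redeclaration.
  if (!is_non_tn_fp8_gemm_supported && is_B_transposed) {
    std::cout << "B: fall back to column-wise (transposed) data\n";
  }
}

int main() {
  canonicalize_gemm_input(/*is_A_transposed=*/false, /*is_B_transposed=*/true);
  return 0;
}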

if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
@@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
2 changes: 1 addition & 1 deletion transformer_engine/common/transformer_engine.cpp
@@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) {
}

int nvte_is_non_tn_fp8_gemm_supported() {
int num_devices = transformer_engine::cuda::num_devices();
static int num_devices = transformer_engine::cuda::num_devices();
Contributor:
Making num_devices static has a subtle initialization order issue. The static initialization happens once when the function is first called, but the subsequent static vectors cache and flags depend on num_devices for their size.

If transformer_engine::cuda::num_devices() returns different values across multiple calls (which shouldn't happen in practice but isn't guaranteed by the API), the first call to this function will initialize num_devices, and subsequent calls will use that cached value. However, if the CUDA context changes or devices are added/removed (in rare scenarios), this could cause a mismatch.

Consider:

static int num_devices = transformer_engine::cuda::num_devices();

This is initialized once, but cache and flags vectors might need a different size if the device count somehow changes. While unlikely, this could cause out-of-bounds access.

A safer approach might be:

static int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
// std::once_flag is neither copyable nor movable, so the flags cannot live
// in a vector that is resized after construction; a heap array avoids that.
static std::unique_ptr<std::once_flag[]> flags;
static std::once_flag init_flag;
std::call_once(init_flag, [&]() {
  cache.resize(num_devices, -1);
  flags = std::make_unique<std::once_flag[]>(num_devices);
});

Or simply document that the device count must not change during the application's lifetime.

Contributor:
Critical: Potential out-of-bounds access with static num_devices

Making num_devices static creates a serious bug when GPUs are hot-plugged after the first call to this function. The cache and flags vectors are sized based on the initial device count, but device_id from current_device() could exceed num_devices if GPUs are added later.

This will cause out-of-bounds access on lines 968 and 975:

std::call_once(flags[device_id], ...);  // OOB if device_id >= num_devices
return cache[device_id];                 // OOB if device_id >= num_devices

Impact: Undefined behavior, potential crashes, memory corruption.

Fix: Query device count each time, or add bounds checking:

int num_devices = transformer_engine::cuda::num_devices();
static std::vector<int> cache;
static std::mutex cache_mutex;

int device_id = transformer_engine::cuda::current_device();
std::lock_guard<std::mutex> lock(cache_mutex);
// std::once_flag is neither copyable nor movable, so a vector<std::once_flag>
// cannot be resized after construction; with the mutex held, a sentinel value
// in the cache replaces std::call_once.
if (device_id >= static_cast<int>(cache.size())) {
  cache.resize(device_id + 1, -1);
}
if (cache[device_id] < 0) {
  // ... compute and store the per-device flag ...
}
return cache[device_id];

Contributor:
Potential race condition in static variable initialization

Making num_devices static improves performance by caching the device count. The concern: while the cache and flags vectors are protected by std::call_once, num_devices itself is initialized on the first function call.

Scenario:

  1. Thread A enters the function for the first time and starts evaluating transformer_engine::cuda::num_devices()
  2. Thread B enters the function while thread A is still inside num_devices()

In practice this is safe: since C++11 the standard guarantees thread-safe initialization of function-local statics ("magic statics"), so thread B blocks until thread A completes the initialization. The value also doesn't change after initialization, and the function is typically first called during single-threaded module initialization.

Wrapping the initializer in an immediately-invoked lambda,

static int num_devices = []() {
  return transformer_engine::cuda::num_devices();
}();

is equivalent and adds no extra safety; no std::call_once is needed for num_devices itself.

static std::vector<int> cache(num_devices, -1);
static std::vector<std::once_flag> flags(num_devices);
Contributor:
Potential edge case: The num_devices is captured once as a static variable, so if the number of devices changes at runtime (e.g., with CUDA MPS or device enumeration changes), the cache and flags vectors could be undersized, leading to out-of-bounds access at line 968.

While dynamic device changes are rare, consider adding a bounds check:

Suggested change:
 static int num_devices = transformer_engine::cuda::num_devices();
 static std::vector<int> cache(num_devices, -1);
 static std::vector<std::once_flag> flags(num_devices);
+int device_id = transformer_engine::cuda::current_device();
+NVTE_CHECK(device_id < num_devices, "Device ID exceeds cached device count");

Contributor:
Potential out-of-bounds access if the device count changes

Making num_devices static causes it to be evaluated only once, at the first call. Note that the CUDA runtime renumbers visible devices contiguously from 0 (with CUDA_VISIBLE_DEVICES=0,2 the visible IDs are 0 and 1), so current_device() normally stays within [0, num_devices). The remaining risk is the device count changing after the first call:

  1. Bug scenario: num_devices is initialized to 2, but the visible device count later grows and current_device() returns 2; cache[device_id] at line 975 then reads out of bounds (cache size is 2, valid indices are 0-1).
  2. Impact: undefined behavior, likely a segmentation fault or memory corruption.

Fix: Either:

  • Bounds-check device_id against num_devices (a sketch follows below)
  • Remove the static keyword and re-evaluate num_devices on each call (sacrifices some of the optimization but ensures correctness)
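A minimal runnable sketch of the bounds-checked variant; query_num_devices and query_current_device are hypothetical stand-ins for the transformer_engine::cuda helpers, and the throw stands in for NVTE_CHECK:

#include <mutex>
#include <stdexcept>
#include <vector>

static int query_num_devices() { return 2; }    // stand-in for cuda::num_devices()
static int query_current_device() { return 0; } // stand-in for cuda::current_device()

int is_non_tn_fp8_gemm_supported_cached() {
  static int num_devices = query_num_devices();
  static std::vector<int> cache(num_devices, -1);
  static std::vector<std::once_flag> flags(num_devices);

  int device_id = query_current_device();
  // Cheap guard against a device count that changed after the statics
  // were initialized.
  if (device_id < 0 || device_id >= num_devices) {
    throw std::out_of_range("Device ID exceeds cached device count");
  }
  std::call_once(flags[device_id], [&] {
    cache[device_id] = 1;  // placeholder for the real per-device capability query
  });
  return cache[device_id];
}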

int device_id = transformer_engine::cuda::current_device();
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
@@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
The order of attributes checked is important to also
minimize overhead.
"""
if hasattr(tensor, "device"):
return tensor.device.index
if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
return tensor._rowwise_data.device.index
if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None:
@@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int:
return tensor._data.device.index
if hasattr(tensor, "_transpose") and tensor._transpose is not None:
return tensor._transpose.device.index
if hasattr(tensor, "device"):
Comment on lines 79 to 87
Contributor:
Reordering optimizes for quantized tensors but creates a performance regression for regular tensors. Before, regular torch.Tensor objects checked device immediately (one attribute check). Now they check four non-existent quantized attributes via hasattr() first.

If get_tensor_device() is frequently called with regular tensors, consider adding an early type check:

if type(tensor).__name__ == 'Tensor':
    return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()

return tensor.device.index
Comment on lines 79 to 88
Contributor:
Performance regression for regular tensors

The reordering of attribute checks optimizes for quantized tensors but creates a performance regression for regular torch.Tensor objects.

Before: Regular tensors checked device immediately (the first attribute check) and returned
After: Regular tensors check four non-existent attributes via hasattr() before checking device

Impact:

  • Quantized tensors: ~4 fewer attribute checks ✓
  • Regular tensors: ~4 additional attribute checks ✗

Concern: If regular tensors are passed to get_tensor_device() frequently (which seems likely given the function name), this could offset the gains from other optimizations in this PR.

Recommendation: Profile both paths or add an early type check:

def get_tensor_device(tensor: torch.Tensor) -> int:
    # Fast path for regular tensors
    if type(tensor).__name__ == 'Tensor':
        return tensor.device.index if hasattr(tensor, 'device') else torch.cuda.current_device()
    
    # Quantized tensor path
    if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None:
        return tensor._rowwise_data.device.index
    # ... etc

return torch.cuda.current_device()
Comment on lines 79 to 89
Contributor:
Reordering device attribute checks changes behavior for regular tensors

The original code checked tensor.device FIRST (lines 78-79 in the old code), then checked quantized tensor attributes. The new code reverses this order, checking quantized attributes first.

Impact:

  • Performance: For regular torch.Tensor objects, this INCREASES overhead by attempting to check for _rowwise_data, _columnwise_data, _data, and _transpose attributes (all of which will fail via hasattr) before finally checking device.
  • Correctness: For QuantizedTensors that happen to have a .device attribute that differs from their internal data device, this could change which device is returned.

Concern: This optimization seems to hurt performance for the common case (regular tensors) while only helping the less common case (quantized tensors).

Question: What is the relative frequency of this function being called with regular tensors vs quantized tensors? If regular tensors are more common, this change may degrade overall performance.



7 changes: 3 additions & 4 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
bool is_extension_initialized = false;
Contributor:
Race condition: The is_extension_initialized flag is not thread-safe. If init_extension() is called concurrently from multiple threads, both threads could pass the check at line 105 before either sets the flag at line 110, causing the initialization code to run multiple times.

The original implementation using null-checks on the individual class pointers was safer because Python's module import system provides thread safety. Consider using std::once_flag and std::call_once for proper thread-safe initialization:

Suggested change:
 bool is_extension_initialized = false;
+std::once_flag extension_init_flag;

And in init_extension():

void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}


void init_float8_extension() {
if (Float8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
Float8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
@@ -54,7 +54,6 @@ void init_float8_extension() {
}
Comment on lines 40 to 54
Contributor:
Removed null-check guards without proper synchronization guarantee

The original code had null checks (if (Float8TensorPythonClass) return;) to prevent redundant initialization. These were removed in favor of std::call_once, but the individual init_* functions no longer have guards.

Issue: If init_extension() is called multiple times from different threads, std::call_once ensures the lambda runs once. However, if init_float8_extension() is called directly (not through init_extension()), it will re-import modules and reassign global pointers without protection.

Verification needed: Confirm that these functions are NEVER called directly, only through init_extension(). If they can be called directly, the null checks should be restored:

void init_float8_extension() {
  if (Float8TensorPythonClass) return;  // Keep this guard
  auto fp8_module = py::module_::import(...);
  ...
}


void init_mxfp8_extension() {
if (MXFP8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
MXFP8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
@@ -69,7 +68,6 @@ void init_mxfp8_extension() {
}

void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorStoragePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
@@ -90,7 +88,6 @@ void init_float8blockwise_extension() {
}

void init_nvfp4_extensions() {
if (NVFP4TensorPythonClass) return;
auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
@@ -105,10 +102,12 @@ void init_nvfp4_extensions() {
}

void init_extension() {
if (is_extension_initialized) return;
init_float8_extension();
init_mxfp8_extension();
init_float8blockwise_extension();
init_nvfp4_extensions();
is_extension_initialized = true;
}
Comment on lines +105 to 111
Contributor:
The std::call_once suggested above would ensure thread-safe initialization (the plain bool flag in the diff does not). However, the individual init functions (init_float8_extension, etc.) previously had null-check guards that were removed.

If these functions are ever called directly (not through init_extension()), they'll re-import modules and reassign global pointers without protection. Verify they're only called through init_extension(); a sketch combining both fixes follows below.
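A minimal runnable sketch combining both fixes; the globals and sibling init_* functions are hypothetical stand-ins for those in the diff:

#include <mutex>

static void *Float8TensorPythonClass = nullptr;  // stand-in for the PyTypeObject*
static void init_mxfp8_extension() {}
static void init_float8blockwise_extension() {}
static void init_nvfp4_extensions() {}

static std::once_flag extension_init_flag;

static void init_float8_extension() {
  if (Float8TensorPythonClass) return;  // restored guard: idempotent if called directly
  // ... import the Python module and assign the global class pointers ...
  Float8TensorPythonClass = reinterpret_cast<void *>(0x1);  // placeholder assignment
}

void init_extension() {
  // The lambda runs exactly once, even under concurrent callers.
  std::call_once(extension_init_flag, [] {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
  });
}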


} // namespace transformer_engine::pytorch