Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
80 commits
Select commit Hold shift + click to select a range
93ee022
add all the optimizations
vthumbe1503 Jan 5, 2026
06338bc
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2026
50de9cd
requires_grad optimization
vthumbe1503 Jan 6, 2026
5fee841
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 6, 2026
4c79ac7
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 6, 2026
62b88e1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 6, 2026
99494d7
test if commenting out requires_grad works
vthumbe1503 Jan 7, 2026
b157f85
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 7, 2026
2a7b627
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 7, 2026
b61a6a8
fix minor bug
vthumbe1503 Jan 7, 2026
938651e
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 7, 2026
88dfdbd
fix ci
vthumbe1503 Jan 11, 2026
1526eea
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 11, 2026
5809dcc
missed a bug
vthumbe1503 Jan 11, 2026
b3bd748
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 11, 2026
30fecf2
Update transformer_engine/pytorch/csrc/quantizer.cpp
vthumbe1503 Jan 11, 2026
1b0d497
fix some bugs pointed to by copilot
vthumbe1503 Jan 11, 2026
138b7bf
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 11, 2026
eec1e86
linting error
vthumbe1503 Jan 11, 2026
8169d9c
fix the error
vthumbe1503 Jan 12, 2026
6fefaf2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 12, 2026
a5feaf9
fix the bug
vthumbe1503 Jan 13, 2026
285dbff
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 13, 2026
afb2f23
get rid of the change
vthumbe1503 Jan 13, 2026
3919cb8
fix the transpose shape bug
vthumbe1503 Jan 13, 2026
fd36424
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 13, 2026
4668133
minor linter fix
vthumbe1503 Jan 13, 2026
5a00652
fix lint
vthumbe1503 Jan 13, 2026
739bbad
fix linting error
vthumbe1503 Jan 16, 2026
e8042c1
address copilot review comment regarding error check when both data a…
vthumbe1503 Jan 16, 2026
1d323d7
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
da7fbf5
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 16, 2026
e2c7435
fix linting errors
vthumbe1503 Jan 16, 2026
f4e2492
fix merge conflict
vthumbe1503 Jan 16, 2026
beada36
missed a merge conflict
vthumbe1503 Jan 16, 2026
06a72a2
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
5d21db2
final optimizations
vthumbe1503 Jan 16, 2026
1dfd6fe
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 16, 2026
8c8dd20
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2026
c1acd62
fix ci error
vthumbe1503 Jan 18, 2026
7f35b0b
fix merge conflict
vthumbe1503 Jan 18, 2026
ca177ae
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 18, 2026
1538fd9
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2026
710b581
address review comment from greptile
vthumbe1503 Jan 18, 2026
8a57a75
fix merge conflict
vthumbe1503 Jan 18, 2026
7e4f093
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 18, 2026
8604b69
address review comment + stride optimization
vthumbe1503 Jan 19, 2026
de44954
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 19, 2026
cc50745
address linter issue
vthumbe1503 Jan 19, 2026
f2e9a5d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 19, 2026
0d75c3e
minor lint
vthumbe1503 Jan 20, 2026
3d9f673
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Jan 20, 2026
53e8e4e
fix ci bug
vthumbe1503 Jan 20, 2026
c746abd
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 20, 2026
9c922f5
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 20, 2026
88b782a
another optimization to do at::native::empty_cuda directly instead of…
vthumbe1503 Jan 20, 2026
5562cbe
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 20, 2026
14adf1a
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 23, 2026
1e28aa8
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 25, 2026
c651d65
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Jan 26, 2026
e07b5b3
cleanups
vthumbe1503 Feb 23, 2026
3f2da29
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Feb 23, 2026
06ac237
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 23, 2026
24a8f3d
better solution for device
vthumbe1503 Feb 24, 2026
853ddd5
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Feb 24, 2026
3db390d
enum to int cache
vthumbe1503 Feb 24, 2026
aaf5347
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
0bf040f
remove unused function
vthumbe1503 Feb 24, 2026
b7d9693
Update transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
vthumbe1503 Feb 24, 2026
15165b7
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Feb 24, 2026
369afeb
index instead of device bug
vthumbe1503 Feb 24, 2026
cb73444
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
c7bb5ce
fix ci:
vthumbe1503 Feb 24, 2026
f934261
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Feb 24, 2026
1843f02
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Feb 24, 2026
89d8d82
debug quantized tensor fix
vthumbe1503 Feb 27, 2026
63509e6
Merge branch 'cpu_fp8_optimizations' of github.com:vthumbe1503/Transf…
vthumbe1503 Feb 27, 2026
a77195a
Merge branch 'main' into cpu_fp8_optimizations
vthumbe1503 Feb 27, 2026
4e92a46
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 27, 2026
73e4d1d
revert cuDNN front-end change
vthumbe1503 Feb 27, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions tests/pytorch/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@
scaled_init_method_normal,
)
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx
import transformer_engine_torch as tex
from transformer_engine.pytorch.quantized_tensor import (
Quantizer,
Expand Down Expand Up @@ -2581,12 +2582,12 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker:
_2X_ACC_DGRAD = False
_2X_ACC_WGRAD = False

# FP8 tensor-meta indices for the fused-attention GEMMs in this test.
# The diff concatenation left two duplicate assignment sets (raw tex enums
# followed by cached ints); keep only the cached-integer set, which avoids
# repeated enum -> int PyObject conversions on the hot path.
META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT
META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1
META_O = FP8FwdTensorIdx.GEMM2_INPUT
META_DO = FP8BwdTensorIdx.GRAD_INPUT2
META_S = FP8FwdTensorIdx.GEMM3_OUTPUT
META_DP = FP8BwdTensorIdx.GRAD_INPUT3


class _custom_mha_fp8(torch.autograd.Function):
Expand Down Expand Up @@ -2614,14 +2615,14 @@ def forward(
d = in_features // h
b = cu_seqlens.numel() - 1

input_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
qkv_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT]
qkv_weight_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
o_quantizer = quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
dO_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
dQKV_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
s_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT2]
dP_quantizer = quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT3]
input_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
qkv_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM2_INPUT]
qkv_weight_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
o_quantizer = quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
dO_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
dQKV_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
s_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT2]
dP_quantizer = quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT3]

inp_fp8 = input_quantizer(inp)

Expand Down
21 changes: 11 additions & 10 deletions tests/pytorch/test_custom_recipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import transformer_engine.pytorch as te
import transformer_engine_torch as tex
from transformer_engine.common import recipe
from transformer_engine.pytorch.constants import FP8BwdTensorIdx, FP8FwdTensorIdx
from transformer_engine.pytorch import (
autocast,
Linear,
Expand Down Expand Up @@ -169,11 +170,11 @@ def test_custom_recipe_matches_current_scaling():
with autocast(enabled=True, recipe=ref_recipe):
out_ref = model_ref(inp_ref)
# Assert dtypes for reference quantizers: HYBRID = E4M3 (fwd), E5M2 (bwd)
ref_fwd_in = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
ref_fwd_w = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
ref_fwd_out = model_ref.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
ref_bwd_go = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
ref_bwd_gi = model_ref.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
ref_fwd_in = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
ref_fwd_w = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
ref_fwd_out = model_ref.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
ref_bwd_go = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
ref_bwd_gi = model_ref.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
assert ref_fwd_in.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_w.dtype == tex.DType.kFloat8E4M3
assert ref_fwd_out.dtype == tex.DType.kFloat8E4M3
Expand All @@ -200,11 +201,11 @@ def quantizer_factory(role):
with autocast(enabled=True, recipe=custom_recipe):
out_custom = model_custom(inp_custom)
# Assert dtypes for custom quantizers match reference mapping
cus_fwd_in = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT]
cus_fwd_w = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT]
cus_fwd_out = model_custom.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_OUTPUT]
cus_bwd_go = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_OUTPUT1]
cus_bwd_gi = model_custom.quantizers["scaling_bwd"][tex.FP8BwdTensors.GRAD_INPUT1]
cus_fwd_in = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_INPUT]
cus_fwd_w = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_WEIGHT]
cus_fwd_out = model_custom.quantizers["scaling_fwd"][FP8FwdTensorIdx.GEMM1_OUTPUT]
cus_bwd_go = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_OUTPUT1]
cus_bwd_gi = model_custom.quantizers["scaling_bwd"][FP8BwdTensorIdx.GRAD_INPUT1]
assert cus_fwd_in.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_w.dtype == tex.DType.kFloat8E4M3
assert cus_fwd_out.dtype == tex.DType.kFloat8E4M3
Expand Down
12 changes: 8 additions & 4 deletions transformer_engine/common/gemm/cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,10 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
// Set conditions for MXFP8 and NVFP4 gemm execution.
const auto nvfp4 = is_nvfp_scaling(A.scaling_mode) && is_nvfp_scaling(B.scaling_mode);
const auto mxfp8 = !nvfp4 && is_mxfp_scaling(A.scaling_mode) && is_mxfp_scaling(B.scaling_mode);
int is_nvte_non_tn_fp8_gemm_supported = 0; // needed only for per tensor scaling
if (is_tensor_scaling(A.scaling_mode) || is_tensor_scaling(B.scaling_mode)) {
is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported();
}

// Configure A matrix
if (is_tensor_scaling(A.scaling_mode)) {
Expand All @@ -129,7 +133,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Atype = A.data.dtype;
ret.A_scale_inv = A.scale_inv.dptr;
ret.lda = is_A_transposed ? k : m;
if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) {
if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) {
ret.A = A.columnwise_data.dptr;
Expand All @@ -140,7 +144,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype),
Expand Down Expand Up @@ -220,7 +224,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
ret.Btype = B.data.dtype;
ret.B_scale_inv = B.scale_inv.dptr;
ret.ldb = is_B_transposed ? n : k;
if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) {
if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) {
// Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data.
if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) {
ret.B = B.columnwise_data.dptr;
Expand All @@ -231,7 +235,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla
} else {
NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage");
}
} else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) {
} else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) {
// Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed
// data with the mirrored transpose-flag if we don't have row-wise data.
NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype),
Expand Down
21 changes: 20 additions & 1 deletion transformer_engine/common/util/cuda_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

#include <cuda.h>

#include <mutex>
#include <string>
#include <unordered_map>

#include "../common.h"
#include "../util/string.h"
Expand All @@ -29,13 +31,30 @@ void *get_symbol(const char *symbol, int cuda_version = 12010);
* without GPUs. Indirect function calls into a lazily-initialized
* library ensures we are accessing the correct version.
*
* Symbol pointers are cached to avoid repeated lookups.
*
* \param[in] symbol Function name
* \param[in] args Function arguments
*/
/*! \brief Call a CUDA driver API entry point by name.
 *
 * The driver library is loaded lazily, so entry points are resolved with
 * get_symbol. Resolved pointers are cached in a map to avoid repeating the
 * (expensive) symbol lookup on every call.
 *
 * Fixes relative to the previous revision:
 *  - removes the duplicated declaration of `func` (redefinition error);
 *  - resolves the symbol OUTSIDE the mutex so a slow get_symbol does not
 *    serialize unrelated driver calls (a benign duplicate lookup may occur
 *    on a first-use race, which is safe since get_symbol is deterministic).
 *
 * \param[in] symbol Function name
 * \param[in] args   Function arguments
 */
template <typename... ArgTs>
inline CUresult call(const char *symbol, ArgTs... args) {
  using FuncT = CUresult(ArgTs...);

  // Cache of symbol name -> resolved function pointer. Guarded by a mutex
  // since driver calls may come from multiple threads.
  static std::unordered_map<std::string, void *> symbol_cache;
  static std::mutex cache_mutex;

  void *ptr = nullptr;
  {
    std::lock_guard<std::mutex> lock(cache_mutex);
    auto it = symbol_cache.find(symbol);
    if (it != symbol_cache.end()) {
      ptr = it->second;
    }
  }
  if (ptr == nullptr) {
    // Resolve outside the lock; insert the result for future calls.
    ptr = get_symbol(symbol);
    std::lock_guard<std::mutex> lock(cache_mutex);
    symbol_cache.emplace(symbol, ptr);
  }
  return (*reinterpret_cast<FuncT *>(ptr))(args...);
}

Expand Down
9 changes: 9 additions & 0 deletions transformer_engine/debug/pytorch/debug_quantization.py
Original file line number Diff line number Diff line change
Expand Up @@ -697,3 +697,12 @@ def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None
raise RuntimeError(
"Cannot recreate columnwise tensor from rowwise tensor is debug mode."
)

@property
def device(self):
    """Device of the underlying data tensor.

    Defined explicitly so attribute access avoids expensive PyObject
    lookups. Prefers the row-wise GEMM tensor, then the column-wise one.
    """
    for backing in (self.rowwise_gemm_tensor, self.columnwise_gemm_tensor):
        if backing is not None:
            return backing.device
    raise RuntimeError("DebugQuantizedTensor has no data!")
20 changes: 20 additions & 0 deletions transformer_engine/pytorch/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# See LICENSE for license information.

"""Enums for e2e transformer"""
from types import SimpleNamespace
import torch
import torch.distributed
import transformer_engine_torch as tex
Expand Down Expand Up @@ -40,6 +41,25 @@
tex.DType.kBFloat16: torch.bfloat16,
}

# Cache enum -> int conversions to avoid repeated PyObject lookups.
FP8FwdTensorIdx = SimpleNamespace(
    **{
        name: int(getattr(tex.FP8FwdTensors, name))
        for name in (
            "GEMM1_INPUT",
            "GEMM1_WEIGHT",
            "GEMM1_OUTPUT",
            "GEMM2_INPUT",
            "GEMM2_WEIGHT",
            "GEMM2_OUTPUT",
            "GEMM3_OUTPUT",
        )
    }
)
FP8BwdTensorIdx = SimpleNamespace(
    **{
        name: int(getattr(tex.FP8BwdTensors, name))
        for name in (
            "GRAD_INPUT1",
            "GRAD_INPUT2",
            "GRAD_INPUT3",
            "GRAD_OUTPUT1",
            "GRAD_OUTPUT2",
            "GRAD_OUTPUT3",
        )
    }
)

AttnMaskTypes = (
"no_mask",
"padding",
Expand Down
13 changes: 7 additions & 6 deletions transformer_engine/pytorch/cpp_extensions/fused_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
NVTE_Fused_Attn_Backend,
)
from ..quantized_tensor import Quantizer
from ..constants import FP8BwdTensorIdx, FP8FwdTensorIdx


__all__ = [
Expand Down Expand Up @@ -103,12 +104,12 @@
BACKEND_F16m512_FP8_THREADS_PER_CTA = 128
BACKEND_F16arb_ELTS_PER_THREADS = 16

# FP8 tensor-meta indices for the fused-attention GEMMs.
# The diff concatenation left two duplicate assignment sets (raw tex enums
# followed by cached ints); keep only the cached-integer set
# (FP8FwdTensorIdx/FP8BwdTensorIdx), which avoids repeated enum -> int
# PyObject conversions on the hot path.
META_QKV = FP8FwdTensorIdx.GEMM1_OUTPUT
META_DQKV = FP8BwdTensorIdx.GRAD_OUTPUT1
META_O = FP8FwdTensorIdx.GEMM2_INPUT
META_DO = FP8BwdTensorIdx.GRAD_INPUT2
META_S = FP8FwdTensorIdx.GEMM3_OUTPUT
META_DP = FP8BwdTensorIdx.GRAD_INPUT3


def fused_attn_fwd(
Expand Down
26 changes: 2 additions & 24 deletions transformer_engine/pytorch/cpp_extensions/gemm.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,28 +67,6 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float:
return 0.0


def get_tensor_device(tensor: torch.Tensor) -> int:
    """Return the index of the device holding ``tensor``.

    Uses cheap ``hasattr``/``getattr`` probes rather than ``isinstance``
    checks against QuantizedTensor or Storage classes, which incur more CPU
    overhead. The probe order is chosen to minimize that overhead. Falls
    back to the current CUDA device when no backing data is found.
    """
    if hasattr(tensor, "device"):
        return tensor.device.index
    # Probe the possible backing-data attributes of quantized tensors.
    for attr in ("_rowwise_data", "_columnwise_data", "_data", "_transpose"):
        data = getattr(tensor, attr, None)
        if data is not None:
            return data.device.index
    return torch.cuda.current_device()


def general_gemm(
A: torch.Tensor,
B: torch.Tensor,
Expand Down Expand Up @@ -117,7 +95,7 @@ def general_gemm(

alpha = validate_gemm_scale(alpha, True)
beta = validate_gemm_scale(beta, accumulate)
workspace = get_cublas_workspace(get_tensor_device(A), ub is not None, False)
workspace = get_cublas_workspace(A.device.index, ub is not None, False)

if ub_type is not None:
assert ub is not None, (
Expand Down Expand Up @@ -235,7 +213,7 @@ def general_grouped_gemm(
out_dtype = TE_DType[out[0].dtype] if D_dtype is None else D_dtype

sm_count = get_sm_count()
workspaces = get_cublas_workspace(get_tensor_device(A[0]), False, True)
workspaces = get_cublas_workspace(A[0].device.index, False, True)

if grad and use_bias:
grad_bias = [
Expand Down
17 changes: 8 additions & 9 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,10 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr;
PyTypeObject *NVFP4TensorPythonClass = nullptr;
PyTypeObject *NVFP4TensorStoragePythonClass = nullptr;
PyTypeObject *NVFP4QuantizerClass = nullptr;
std::once_flag extension_init_flag;
PyTypeObject *GroupedTensorStoragePythonClass = nullptr;

void init_float8_extension() {
if (Float8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor");
Float8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer"));
Expand All @@ -55,7 +55,6 @@ void init_float8_extension() {
}

void init_mxfp8_extension() {
if (MXFP8TensorPythonClass) return;
auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor");
MXFP8QuantizerClass =
reinterpret_cast<PyTypeObject *>(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer"));
Expand All @@ -70,7 +69,6 @@ void init_mxfp8_extension() {
}

void init_float8blockwise_extension() {
if (Float8BlockwiseQTensorStoragePythonClass) return;
auto fp8_module =
py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor");
auto fp8_base_module = py::module_::import(
Expand All @@ -91,7 +89,6 @@ void init_float8blockwise_extension() {
}

void init_nvfp4_extensions() {
if (NVFP4TensorPythonClass) return;
auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor");
NVFP4QuantizerClass = reinterpret_cast<PyTypeObject *>(
PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer"));
Expand All @@ -116,11 +113,13 @@ void init_grouped_tensor_extension() {
}

/*! \brief Initialize cached Python type objects for all tensor extensions.
 *
 * The diff concatenation left the per-extension initializers listed twice
 * (once bare, once inside call_once); keep only the std::call_once form,
 * which guarantees the initializers run exactly once even under concurrent
 * first use, replacing the per-function "already initialized" checks.
 */
void init_extension() {
  std::call_once(extension_init_flag, []() {
    init_float8_extension();
    init_mxfp8_extension();
    init_float8blockwise_extension();
    init_nvfp4_extensions();
    init_grouped_tensor_extension();
  });
}

} // namespace transformer_engine::pytorch
Expand Down
Loading
Loading