diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 4b7d8179b0..85d89bff8e 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -129,7 +129,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Atype = A.data.dtype; ret.A_scale_inv = A.scale_inv.dptr; ret.lda = is_A_transposed ? k : m; - if (!nvte_is_non_tn_fp8_gemm_supported() && !is_A_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && !is_A_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype)) { ret.A = A.columnwise_data.dptr; @@ -140,7 +141,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Atype), "Input A is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !A.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !A.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(A.has_columnwise_data() && is_fp8_dtype(A.columnwise_data.dtype), @@ -220,7 +221,8 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla ret.Btype = B.data.dtype; ret.B_scale_inv = B.scale_inv.dptr; ret.ldb = is_B_transposed ? n : k; - if (!nvte_is_non_tn_fp8_gemm_supported() && is_B_transposed) { + int is_nvte_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + if (!is_nvte_non_tn_fp8_gemm_supported && is_B_transposed) { // Hopper only supports TN GEMMs for FP8. "Column-wise data" is transpose of data. if (B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype)) { ret.B = B.columnwise_data.dptr; @@ -231,7 +233,7 @@ GemmParam CanonicalizeGemmInput(const transformer_engine::Tensor &A, const cubla } else { NVTE_CHECK(!is_fp8_dtype(ret.Btype), "Input B is missing column-wise usage"); } - } else if (nvte_is_non_tn_fp8_gemm_supported() && !B.has_data()) { + } else if (is_nvte_non_tn_fp8_gemm_supported && !B.has_data()) { // Blackwell supports any GEMM layout for FP8, so we can use column-wise/transposed // data with the mirrored transpose-flag if we don't have row-wise data. NVTE_CHECK(B.has_columnwise_data() && is_fp8_dtype(B.columnwise_data.dtype), diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 370d9723cf..82c50c4ebd 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -961,7 +961,7 @@ void nvte_destroy_quantization_config(NVTEQuantizationConfig config) { } int nvte_is_non_tn_fp8_gemm_supported() { - int num_devices = transformer_engine::cuda::num_devices(); + static int num_devices = transformer_engine::cuda::num_devices(); static std::vector cache(num_devices, -1); static std::vector flags(num_devices); int device_id = transformer_engine::cuda::current_device(); diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 2a97e2ac71..7fe37b5f54 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -76,8 +76,6 @@ def get_tensor_device(tensor: torch.Tensor) -> int: The order of attributes checked is important to also minimize overhead. """ - if hasattr(tensor, "device"): - return tensor.device.index if hasattr(tensor, "_rowwise_data") and tensor._rowwise_data is not None: return tensor._rowwise_data.device.index if hasattr(tensor, "_columnwise_data") and tensor._columnwise_data is not None: @@ -86,6 +84,8 @@ def get_tensor_device(tensor: torch.Tensor) -> int: return tensor._data.device.index if hasattr(tensor, "_transpose") and tensor._transpose is not None: return tensor._transpose.device.index + if hasattr(tensor, "device"): + return tensor.device.index return torch.cuda.current_device() diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index e73eca7861..4b14e8c019 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -35,9 +35,9 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; +bool is_extension_initialized = false; void init_float8_extension() { - if (Float8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_tensor"); Float8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "Float8Quantizer")); @@ -54,7 +54,6 @@ void init_float8_extension() { } void init_mxfp8_extension() { - if (MXFP8TensorPythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.mxfp8_tensor"); MXFP8QuantizerClass = reinterpret_cast(PyObject_GetAttrString(fp8_module.ptr(), "MXFP8Quantizer")); @@ -69,7 +68,6 @@ void init_mxfp8_extension() { } void init_float8blockwise_extension() { - if (Float8BlockwiseQTensorStoragePythonClass) return; auto fp8_module = py::module_::import("transformer_engine.pytorch.tensor.float8_blockwise_tensor"); auto fp8_base_module = py::module_::import( @@ -90,7 +88,6 @@ void init_float8blockwise_extension() { } void init_nvfp4_extensions() { - if (NVFP4TensorPythonClass) return; auto nvfp4_module = py::module_::import("transformer_engine.pytorch.tensor.nvfp4_tensor"); NVFP4QuantizerClass = reinterpret_cast( PyObject_GetAttrString(nvfp4_module.ptr(), "NVFP4Quantizer")); @@ -105,10 +102,12 @@ void init_nvfp4_extensions() { } void init_extension() { + if (is_extension_initialized) return; init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); init_nvfp4_extensions(); + is_extension_initialized = true; } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a73efc008a..b5612f6632 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -121,9 +121,9 @@ std::pair Float8Quantizer::create_tensor( const std::vector& shape, DType dtype, std::optional data, std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Initialize data tensor - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data && !data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -134,7 +134,7 @@ std::pair Float8Quantizer::create_tensor( py::object data_py = with_data ? py::cast(*data) : py::none(); // Initialize transpose tensor - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose && !transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -143,26 +143,52 @@ std::pair Float8Quantizer::create_tensor( transpose.reset(); } py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); - + py::object scale_inv_py = py::cast(scale_inv); // Initialize scale-inverse tensor if (!scale_inv) { scale_inv = at::reciprocal(scale); } - + at::Device device = + with_data ? data->device() : (with_transpose ? transpose->device() : torch::kCUDA); // Construct Python FP8 tensor py::object out_py; if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -185,10 +211,10 @@ std::pair Float8Quantizer::create_tensor( std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); // Extract buffers from Python tensor @@ -328,7 +354,8 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize data tensor at::Tensor data_tensor; - const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); + const bool with_data = rowwise_usage || is_non_tn_fp8_gemm_supported; if (with_data) { const std::vector shape_int64(shape.begin(), shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); @@ -337,13 +364,12 @@ std::pair Float8CurrentScalingQuantizer::create_tenso // Initialize transpose tensor at::Tensor transpose_tensor; - const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool with_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; if (with_transpose) { const auto transpose_shape = make_transpose_shape(shape); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); transpose_tensor = at::empty(transpose_shape, opts); } - // Initialize scale-inverse tensor at::Tensor scale_inv_tensor; { @@ -351,23 +377,49 @@ std::pair Float8CurrentScalingQuantizer::create_tenso const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); scale_inv_tensor = at::empty(scale_inv_shape, opts); } - + at::Device device = with_data ? data_tensor.device() + : (with_transpose ? transpose_tensor.device() : torch::kCUDA); // Construct Python FP8 tensor py::object out_py; + py::object scale_inv_py = py::cast(scale_inv_tensor); py::object data_py = with_data ? py::cast(data_tensor) : py::none(); py::object transpose_py = with_transpose ? py::cast(transpose_tensor) : py::none(); if (internal) { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorStoragePythonClass)); - out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + + PyObject* result = PyObject_Call(reinterpret_cast(Float8TensorStoragePythonClass), + PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); const std::vector shape_int64(shape.begin(), shape.end()); - out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "data", data_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_scale_inv", scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "data_transpose", transpose_py.ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "device", py::cast(device).ptr()); + PyObject* result = + PyObject_Call(reinterpret_cast(Float8TensorPythonClass), PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ FP8 tensor @@ -406,10 +458,10 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8CurrentScalingQuantizer must output to Float8Tensor."); - + int is_non_tn_fp8_gemm_supported = nvte_is_non_tn_fp8_gemm_supported(); // Expected buffers - const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); - const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + const bool need_data = rowwise_usage || is_non_tn_fp8_gemm_supported; + const bool need_transpose = columnwise_usage && !is_non_tn_fp8_gemm_supported; NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); // Extract buffers from Python tensor @@ -629,22 +681,46 @@ std::pair Float8BlockQuantizer::create_tensor( py::object ret; if (internal) { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass)); - ret = Float8BlockwiseQTensorClass( - "rowwise_data"_a = data_rowwise, "columnwise_data"_a = data_colwise, - "rowwise_scale_inv"_a = scale_inv_rowwise, "columnwise_scale_inv"_a = scale_inv_colwise, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer, - "is_2D_scaled"_a = (block_scaling_dim == 2), "data_format"_a = data_format); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); + + PyObject* result = + PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorStoragePythonClass), + PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensorStorage instance"); + ret = py::reinterpret_steal(result); } else { - py::handle Float8BlockwiseQTensorClass( - reinterpret_cast(Float8BlockwiseQTensorPythonClass)); - ret = Float8BlockwiseQTensorClass( - "shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), "rowwise_data"_a = data_rowwise, - "columnwise_data"_a = data_colwise, "rowwise_scale_inv"_a = scale_inv_rowwise, - "columnwise_scale_inv"_a = scale_inv_colwise, "fp8_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer, "is_2D_scaled"_a = (block_scaling_dim == 2), - "data_format"_a = data_format); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(torch_shape).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", py::cast(data_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", py::cast(data_colwise).ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", py::cast(scale_inv_rowwise).ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", py::cast(scale_inv_colwise).ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + PyDict_SetItemString(kwargs, "is_2D_scaled", py::cast(block_scaling_dim == 2).ptr()); + PyDict_SetItemString(kwargs, "data_format", py::cast(data_format).ptr()); + + PyObject* result = PyObject_Call(reinterpret_cast(Float8BlockwiseQTensorPythonClass), + PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create Float8BlockwiseQTensor instance"); + ret = py::reinterpret_steal(result); } return {std::move(tensor), std::move(ret)}; @@ -950,20 +1026,41 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve // Construct Python MXFP8 tensor py::object out_py; if (internal) { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorStoragePythonClass)); - out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = PyObject_Call(reinterpret_cast(MXFP8TensorStoragePythonClass), + PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, - "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "fp8_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = + PyObject_Call(reinterpret_cast(MXFP8TensorPythonClass), PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create MXFP8Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ MXFP8 tensor @@ -1234,22 +1331,45 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // Construct Python NVFP4 tensor py::object out_py; if (internal) { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorStoragePythonClass)); - out_py = NVFP4TensorClass( - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); + PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorStoragePythonClass), + PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4TensorStorage instance"); + out_py = py::reinterpret_steal(result); } else { - py::handle NVFP4TensorClass(reinterpret_cast(NVFP4TensorPythonClass)); - out_py = NVFP4TensorClass( - "shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = rowwise_data_py, "columnwise_data"_a = columnwise_data_py, - "rowwise_scale_inv"_a = rowwise_scale_inv_py, - "columnwise_scale_inv"_a = columnwise_scale_inv_py, "amax_rowwise"_a = amax_rowwise_py, - "amax_columnwise"_a = amax_columnwise_py, "fp4_dtype"_a = this->dtype, - "quantizer"_a = this->quantizer); + // Use direct C API call bypassing pybind11 overhead + PyObject* kwargs = PyDict_New(); + PyDict_SetItemString(kwargs, "shape", py::cast(shape_int64).ptr()); + PyDict_SetItemString(kwargs, "dtype", py::cast(GetATenDType(dtype)).ptr()); + PyDict_SetItemString(kwargs, "rowwise_data", rowwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_data", columnwise_data_py.ptr()); + PyDict_SetItemString(kwargs, "rowwise_scale_inv", rowwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "columnwise_scale_inv", columnwise_scale_inv_py.ptr()); + PyDict_SetItemString(kwargs, "amax_rowwise", amax_rowwise_py.ptr()); + PyDict_SetItemString(kwargs, "amax_columnwise", amax_columnwise_py.ptr()); + PyDict_SetItemString(kwargs, "fp4_dtype", py::cast(this->dtype).ptr()); + PyDict_SetItemString(kwargs, "quantizer", this->quantizer.ptr()); + + PyObject* result = + PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), PyTuple_New(0), kwargs); + + Py_DECREF(kwargs); + + NVTE_CHECK(result != nullptr, "Failed to create NVFP4Tensor instance"); + out_py = py::reinterpret_steal(result); } // Construct C++ tensor diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ad5cd04341..368b61b382 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -929,12 +929,11 @@ def set_activation_dtype(self, inp: torch.Tensor) -> None: if torch.is_autocast_enabled(): self.activation_dtype = torch_get_autocast_gpu_dtype() return - + dtype = inp.dtype # All checks after this have already been performed once, thus skip - if self.activation_dtype == inp.dtype: + if self.activation_dtype == dtype: return - dtype = inp.dtype if not self.allow_different_data_and_param_types: for name, param in self.named_parameters(): if param is not None: diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f3220d5860..a49652a2c2 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -93,7 +93,6 @@ def forward( non_tensor_args: Tuple, ) -> torch.Tensor: # pylint: disable=missing-function-docstring - ( is_first_microbatch, fp8, @@ -130,6 +129,10 @@ def forward( debug, ) = non_tensor_args + inp_requires_grad = inp.requires_grad + weight_requires_grad = weight.requires_grad + bias_requires_grad = bias.requires_grad if bias is not None else False + # NVTX label for profiling nvtx_label = "transformer_engine._Linear.forward" if ub_name is not None: @@ -141,7 +144,7 @@ def forward( # Configure tensor-parallel communication tp_world_size = get_distributed_world_size(tp_group) - backward_needs_input = is_grad_enabled and weight.requires_grad + backward_needs_input = is_grad_enabled and weight_requires_grad with_input_all_gather_nccl = ( parallel_mode == "column" and sequence_parallel and not ub_overlap_ag_fprop ) @@ -254,7 +257,7 @@ def forward( # Configure quantizer # No need to set the quantizer states if weight is already quantized if weight_quantizer is not None and not isinstance(weight, QuantizedTensor): - columnwise_usage = is_grad_enabled and inp.requires_grad + columnwise_usage = is_grad_enabled and inp_requires_grad if not columnwise_usage: columnwise_usage = ( is_fp8_activation_recompute_enabled() @@ -379,7 +382,7 @@ def forward( ctx.weight_quantizer = weight_quantizer ctx.backward_input_needs_gather = ( - weight.requires_grad and parallel_mode == "column" and sequence_parallel + weight_requires_grad and parallel_mode == "column" and sequence_parallel ) # Discard unneeded data in input tensor @@ -447,7 +450,7 @@ def forward( ctx.grad_weight_quantizer = grad_weight_quantizer ctx.grad_output_quantizer = grad_output_quantizer ctx.fuse_wgrad_accumulation = fuse_wgrad_accumulation - if fuse_wgrad_accumulation and weight.requires_grad: + if fuse_wgrad_accumulation and weight_requires_grad: # This check is needed to ensure that main_grad is not created # during the forward pass when using MCore FSDP as it creates # the main_grad buffer lazily before backprop @@ -473,12 +476,12 @@ def forward( ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_name = ub_name ctx.tp_size = tp_size - ctx.requires_dgrad = inp.requires_grad - ctx.requires_wgrad = weight.requires_grad + ctx.requires_dgrad = inp_requires_grad + ctx.requires_wgrad = weight_requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False ctx.owns_input = saved_inputmat is not inp - if ctx.fp8 and requires_grad(inp, weight, bias): + if ctx.fp8 and inp_requires_grad and weight_requires_grad and bias_requires_grad: _first_fp8_module = FP8GlobalStateManager.IS_FIRST_FP8_MODULE ctx.reduce_and_update_bwd_fp8_tensors = FP8GlobalStateManager.is_first_fp8_module() if in_fp8_activation_recompute_phase(): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 3414581f7c..62fa9b1114 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -360,9 +360,38 @@ def __new__( requires_grad=requires_grad, device=torch.cuda.current_device() if device is None else device, ) - + instance._dtype = dtype return instance + @property + def dtype(self) -> torch.dtype: + # Attribute access of custom tensors goes through an + # expensive Pyobject lookup. Since dtype for a tensor is never + # change after creation, we cache it in a member variable. + return self._dtype + + # @property + # def requires_grad(self) -> bool: + # # Attribute access of custom tensors goes through an + # # expensive Pyobject lookup. Since requires_grad is set during + # # initialization and may be updated, we cache it in a member variable. + # return self._requires_grad + + # @requires_grad.setter + # def requires_grad(self, value: bool) -> None: + # # Update the cached value + # self._requires_grad = value + # # Call parent class to ensure autograd engine is aware of the change + # torch.Tensor.requires_grad.fset(self, value) + + # def requires_grad_(self, requires_grad: bool = True) -> QuantizedTensor: + # # pylint: disable=missing-function-docstring + # # Update the cached value + # self._requires_grad = requires_grad + # # Call parent class method to ensure autograd engine is aware + # super().requires_grad_(requires_grad) + # return self + def dequantize(self, *, dtype: Optional[torch.dtype] = None) -> torch.Tensor: """Convert quantized data to standard PyTorch tensor""" raise NotImplementedError( diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 43cbdcf9e6..4e04708898 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -910,6 +910,14 @@ def fsdp_post_all_gather( ) return out, all_gather_outputs + @property + def shape(self): + return self._data.shape if self._data is not None else self._transpose.shape + + @property + def is_cuda(self): + return self._data.is_cuda if self._data is not None else self._transpose.is_cuda + @classmethod def _make_in_reduce_ex( cls,