Skip to content

Commit 7ce3702

Browse files
authored
feature: enable cublas for fp4 gemm when cudnn == 9.11.1 or >= 9.13 (#1405)
1 parent 2f62643 commit 7ce3702

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

flashinfer/gemm.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -968,6 +968,20 @@ def _check_cudnn_fp4_availability():
968968
)
969969

970970

971+
def _is_cublas_fp4_available_in_cudnn():
    """Report whether the cuBLAS backend for FP4 GEMM is usable in cuDNN.

    The backend is treated as available when the cuDNN backend version is
    exactly 9.11.1 or is 9.13 or newer.

    Returns:
        bool: True if ``cudnn.backend_version()`` equals 91101 or is at
        least 91300, False otherwise.
    """
    # Shared cuDNN availability check runs before the backend is queried
    # (presumably raises when cuDNN is missing -- confirm in its definition).
    _check_cudnn_availability()

    # Integer-encoded cuDNN backend versions.
    exact_supported_version = 91101  # cuDNN 9.11.1
    minimum_supported_version = 91300  # cuDNN 9.13.0

    installed_version = cudnn.backend_version()
    if installed_version == exact_supported_version:
        return True
    return installed_version >= minimum_supported_version
983+
984+
971985
def _get_native_fp4_dtype():
972986
"""get native fp4 datatype if supported in the torch, otherwise return uint8."""
973987
if hasattr(torch, "float4_e2m1fn_x2"):
@@ -1084,8 +1098,11 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
10841098
graph.validate()
10851099
graph.build_operation_graph()
10861100
graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.B])
1087-
# WAR: the alpha (contains the global scale) is not supported by the cuBLAS backend, need to deselect it.
1088-
graph.deselect_engines(["eng0"])
1101+
1102+
# WAR: The alpha (contains the global scale) is not supported by the cuBLAS backend (eng0)
1103+
# in older cuDNN versions, so we deselect it.
1104+
if not _is_cublas_fp4_available_in_cudnn():
1105+
graph.deselect_engines(["eng0"])
10891106
graph.check_support()
10901107
graph.build_plans()
10911108

0 commit comments

Comments
 (0)