@@ -968,6 +968,20 @@ def _check_cudnn_fp4_availability():
968
968
)
969
969
970
970
971
def _is_cublas_fp4_available_in_cudnn():
    """Check if cuBLAS backend for FP4 GEMM is available in cuDNN."""
    _check_cudnn_availability()

    # FP4 via the cuBLAS backend needs cuDNN 9.11.1 exactly, or any
    # release from 9.13.0 onward (the in-between 9.12.x line lacks it).
    CUDNN_VERSION_9_11_1 = 91101
    CUDNN_VERSION_9_13_0 = 91300
    version = cudnn.backend_version()
    if version == CUDNN_VERSION_9_11_1:
        return True
    return version >= CUDNN_VERSION_9_13_0
971
985
def _get_native_fp4_dtype ():
972
986
"""get native fp4 datatype if supported in the torch, otherwise return uint8."""
973
987
if hasattr (torch , "float4_e2m1fn_x2" ):
@@ -1084,8 +1098,11 @@ def build_cudnn_gemm_block_scale_dequantize_graph(
1084
1098
graph .validate ()
1085
1099
graph .build_operation_graph ()
1086
1100
graph .create_execution_plans ([cudnn .heur_mode .A , cudnn .heur_mode .B ])
1087
- # WAR: the alpha (contains the global scale) is not supported by the cuBLAS backend, need to deselect it.
1088
- graph .deselect_engines (["eng0" ])
1101
+
1102
+ # WAR: The alpha (contains the global scale) is not supported by the cuBLAS backend (eng0)
1103
+ # in older cuDNN versions, so we deselect it.
1104
+ if not _is_cublas_fp4_available_in_cudnn ():
1105
+ graph .deselect_engines (["eng0" ])
1089
1106
graph .check_support ()
1090
1107
graph .build_plans ()
1091
1108
0 commit comments