@@ -45,30 +45,36 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
4545 return None
4646
4747
48- if not has_deep_gemm ():
49- _fp8_gemm_nt_impl : Callable [..., Any ] | None = None
50- _grouped_impl : Callable [..., Any ] | None = None
51- _grouped_masked_impl : Callable [..., Any ] | None = None
52- _per_block_cast_impl : Callable [..., Any ] | None = None
53- else :
54- _dg = importlib .import_module ("deep_gemm" ) # type: ignore
55-
56- _fp8_gemm_nt_impl = _resolve_symbol (
57- _dg ,
58- "fp8_gemm_nt" ,
59- "gemm_fp8_fp8_bf16_nt" ,
60- )
48+ _fp8_gemm_nt_impl : Callable [..., Any ] | None = None
49+ _grouped_impl : Callable [..., Any ] | None = None
50+ _grouped_masked_impl : Callable [..., Any ] | None = None
51+ _per_block_cast_impl : Callable [..., Any ] | None = None
52+
53+
54+ def _lazy_init () -> None :
55+ """Import deep_gemm and resolve symbols on first use."""
56+ global _fp8_gemm_nt_impl , _grouped_impl , _grouped_masked_impl , \
57+ _per_block_cast_impl
58+
59+ # fast path
60+ if (_fp8_gemm_nt_impl is not None or _grouped_impl is not None
61+ or _grouped_masked_impl is not None
62+ or _per_block_cast_impl is not None ):
63+ return
64+
65+ if not has_deep_gemm ():
66+ return
67+
68+ _dg = importlib .import_module ("deep_gemm" )
69+
70+ _fp8_gemm_nt_impl = _resolve_symbol (_dg , "fp8_gemm_nt" ,
71+ "gemm_fp8_fp8_bf16_nt" )
6172 _grouped_impl = _resolve_symbol (
62- _dg ,
63- "m_grouped_fp8_gemm_nt_contiguous" ,
64- "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous" ,
65- )
73+ _dg , "m_grouped_fp8_gemm_nt_contiguous" ,
74+ "m_grouped_gemm_fp8_fp8_bf16_nt_contiguous" )
6675 _grouped_masked_impl = _resolve_symbol (
67- _dg ,
68- "fp8_m_grouped_gemm_nt_masked" ,
69- "m_grouped_gemm_fp8_fp8_bf16_nt_masked" ,
70- )
71-
76+ _dg , "fp8_m_grouped_gemm_nt_masked" ,
77+ "m_grouped_gemm_fp8_fp8_bf16_nt_masked" )
7278 # Try to get per_token_cast_to_fp8 from DeepGEMM math utils.
7379 try :
7480 _math_mod = importlib .import_module (
@@ -80,24 +86,28 @@ def _resolve_symbol(module, new: str, old: str) -> Callable[..., Any] | None:
8086
8187
def fp8_gemm_nt(*args, **kwargs):
    """Forward the call to the resolved DeepGEMM ``fp8_gemm_nt`` kernel.

    ``_lazy_init()`` runs first so that the module-level
    ``_fp8_gemm_nt_impl`` slot has a chance to be populated before it is
    inspected. All positional and keyword arguments are passed through
    unchanged.

    Returns:
        Whatever the resolved implementation returns, or the result of
        ``_missing(*args, **kwargs)`` when no implementation is available.
    """
    _lazy_init()
    impl = _fp8_gemm_nt_impl
    if impl is not None:
        return impl(*args, **kwargs)
    # No DeepGEMM symbol was resolved; hand off to the shared fallback.
    return _missing(*args, **kwargs)
8693
8794
def m_grouped_fp8_gemm_nt_contiguous(*args, **kwargs):
    """Forward the call to the resolved contiguous grouped FP8 GEMM kernel.

    ``_lazy_init()`` runs first so that the module-level ``_grouped_impl``
    slot has a chance to be populated before it is inspected. All
    positional and keyword arguments are passed through unchanged.

    Returns:
        Whatever the resolved implementation returns, or the result of
        ``_missing(*args, **kwargs)`` when no implementation is available.
    """
    _lazy_init()
    impl = _grouped_impl
    if impl is not None:
        return impl(*args, **kwargs)
    # No DeepGEMM symbol was resolved; hand off to the shared fallback.
    return _missing(*args, **kwargs)
92100
93101
def fp8_m_grouped_gemm_nt_masked(*args, **kwargs):
    """Forward the call to the resolved masked grouped FP8 GEMM kernel.

    ``_lazy_init()`` runs first so that the module-level
    ``_grouped_masked_impl`` slot has a chance to be populated before it is
    inspected. All positional and keyword arguments are passed through
    unchanged.

    Returns:
        Whatever the resolved implementation returns, or the result of
        ``_missing(*args, **kwargs)`` when no implementation is available.
    """
    _lazy_init()
    impl = _grouped_masked_impl
    if impl is not None:
        return impl(*args, **kwargs)
    # No DeepGEMM symbol was resolved; hand off to the shared fallback.
    return _missing(*args, **kwargs)
98107
99108
100109def per_block_cast_to_fp8 (x , * args , ** kwargs ):
110+ _lazy_init ()
101111 if _per_block_cast_impl is not None and is_blackwell_deep_gemm_used ():
102112 return _per_block_cast_impl (x , use_ue8m0 = True )
103113 # TODO: refactor the `per_block_cast_to_fp8` from tests to vllm utils
0 commit comments