77from torch ._subclasses import FakeTensor
88from torch .fx import Node
99
10+ from ...flashinfer_utils import get_env_enable_pdl
1011from ..utils .cuda_graph import cuda_graph_state
1112from ..utils .logger import ad_logger
1213from ..utils .node_utils import extract_op_args
@@ -256,9 +257,9 @@ def flashinfer_mha_with_cache(
256257 q_shape_og = q .shape
257258 b , s = q_shape_og [:2 ]
258259
259- q = q .contiguous (). view (b * s , - 1 , head_dim )
260- k = k .contiguous (). view (b * s , - 1 , head_dim )
261- v = v .contiguous (). view (b * s , - 1 , head_dim )
260+ q = q .reshape (b * s , - 1 , head_dim )
261+ k = k .reshape (b * s , - 1 , head_dim )
262+ v = v .reshape (b * s , - 1 , head_dim )
262263
263264 n_heads = q .shape [1 ]
264265 n_kv_heads = k .shape [1 ]
@@ -275,11 +276,12 @@ def flashinfer_mha_with_cache(
275276 sm_scale = scale ,
276277 )
277278
278- # Assuming k_scale = v_scale = 1.0, we just have to cast k and v to fp8 before appending to kv cache
279+ # Assuming k_scale = v_scale = 1.0
279280 k_scale , v_scale = 1.0 , 1.0
281+ # If k_scale != 1.0, divide before casting: k = (k / k_scale).to(torch.float8_e4m3fn); likewise v with v_scale
280282 if k_cache .dtype == torch .float8_e4m3fn :
281- k = ( k / k_scale ) .to (torch .float8_e4m3fn )
282- v = ( v / v_scale ) .to (torch .float8_e4m3fn )
283+ k = k .to (torch .float8_e4m3fn )
284+ v = v .to (torch .float8_e4m3fn )
283285
284286 flashinfer .page .append_paged_kv_cache (
285287 k ,
@@ -300,7 +302,10 @@ def flashinfer_mha_with_cache(
300302 paged_kv_last_page_len ,
301303 pp ,
302304 )
303- y = wrapper .run (q , (k_cache , v_cache ), k_scale = k_scale , v_scale = v_scale )
305+
306+ y = wrapper .run (
307+ q , (k_cache , v_cache ), k_scale = k_scale , v_scale = v_scale , enable_pdl = get_env_enable_pdl ()
308+ )
304309
305310 return y .view (q_shape_og ) # [b,s,n*h_d] or [b,s, n, h_d]
306311
0 commit comments