Skip to content

Commit d6f95a4

Browse files
authored
[None][feat] AutoDeploy: Perf optimization for Attention and rmsnorm (#9719)
Signed-off-by: Chenghao Zhang <211069071+nvchenghaoz@users.noreply.github.com>
1 parent c7b5e3e commit d6f95a4

File tree

2 files changed

+14
-8
lines changed

2 files changed

+14
-8
lines changed

tensorrt_llm/_torch/auto_deploy/custom_ops/flashinfer_attention.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from torch._subclasses import FakeTensor
88
from torch.fx import Node
99

10+
from ...flashinfer_utils import get_env_enable_pdl
1011
from ..utils.cuda_graph import cuda_graph_state
1112
from ..utils.logger import ad_logger
1213
from ..utils.node_utils import extract_op_args
@@ -256,9 +257,9 @@ def flashinfer_mha_with_cache(
256257
q_shape_og = q.shape
257258
b, s = q_shape_og[:2]
258259

259-
q = q.contiguous().view(b * s, -1, head_dim)
260-
k = k.contiguous().view(b * s, -1, head_dim)
261-
v = v.contiguous().view(b * s, -1, head_dim)
260+
q = q.reshape(b * s, -1, head_dim)
261+
k = k.reshape(b * s, -1, head_dim)
262+
v = v.reshape(b * s, -1, head_dim)
262263

263264
n_heads = q.shape[1]
264265
n_kv_heads = k.shape[1]
@@ -275,11 +276,12 @@ def flashinfer_mha_with_cache(
275276
sm_scale=scale,
276277
)
277278

278-
# Assuming k_scale = v_scale = 1.0, we just have to cast k and v to fp8 before appending to kv cache
279+
# Assuming k_scale = v_scale = 1.0
279280
k_scale, v_scale = 1.0, 1.0
281+
# k = (k / k_scale).to(torch.float8_e4m3fn) if k_scale != 1.0, same for v
280282
if k_cache.dtype == torch.float8_e4m3fn:
281-
k = (k / k_scale).to(torch.float8_e4m3fn)
282-
v = (v / v_scale).to(torch.float8_e4m3fn)
283+
k = k.to(torch.float8_e4m3fn)
284+
v = v.to(torch.float8_e4m3fn)
283285

284286
flashinfer.page.append_paged_kv_cache(
285287
k,
@@ -300,7 +302,10 @@ def flashinfer_mha_with_cache(
300302
paged_kv_last_page_len,
301303
pp,
302304
)
303-
y = wrapper.run(q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale)
305+
306+
y = wrapper.run(
307+
q, (k_cache, v_cache), k_scale=k_scale, v_scale=v_scale, enable_pdl=get_env_enable_pdl()
308+
)
304309

305310
return y.view(q_shape_og) # [b,s,n*h_d] or [b,s, n, h_d]
306311

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import flashinfer
44
import torch
55

6+
from ...flashinfer_utils import get_env_enable_pdl
67
from ...modules.mamba.layernorm_gated import _layer_norm_fwd
78
from .triton_kernels.rms_norm import rms_norm
89

@@ -21,7 +22,7 @@ def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) ->
2122
"""
2223
# Flashinfer rmsnorm expects a 2D input
2324
input_flat = input.reshape(-1, input.shape[-1])
24-
rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps)
25+
rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps, enable_pdl=get_env_enable_pdl())
2526
return rmsnorm_flat.reshape(input.shape)
2627

2728

0 commit comments

Comments (0)