
Commit 1ec1448

suyoggupta and lucaslie committed
Fix cudagraphs, add rms norm pattern matcher (#87)
* fix overlap scheduler in AD
* cleanups
* fix nest sequences
* nits
* avoid hardcoding max beam width
* avoid hardcoding max beam width
* cudagraph fixes + rms norm
* fix test
* revert ad_executor changes
* Review comments + make sure num_pages >= max batch size
* wrapping reviewer feedback and open items

---------

Signed-off-by: Suyog Gupta <[email protected]>
Signed-off-by: Lucas Liebenwein <[email protected]>
Co-authored-by: Lucas Liebenwein <[email protected]>
1 parent c2d2065 commit 1ec1448

12 files changed: 305 additions, 26 deletions


tensorrt_llm/_torch/auto_deploy/compile/backends/torch_cudagraph.py

Lines changed: 6 additions & 4 deletions
@@ -35,10 +35,11 @@ def __init__(
         self._out_buffer_flat: List[torch.Tensor] = None
         self._args_hash: Optional[Tuple[int, ...]] = None
         self.cuda_graph_batch_sizes = (
-            cuda_graph_batch_sizes
+            sorted(cuda_graph_batch_sizes, reverse=True)
             if cuda_graph_batch_sizes is not None
             else self._get_graph_batch_sizes(self.max_batch_size)
         )
+        self._cuda_graph_mem_pool = None

     def _get_hash(self, flat_args: List[Any]) -> Tuple[int, ...]:
         return tuple(hash(a) for a in flat_args)
@@ -64,7 +65,7 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         # capture graph now
         torch.cuda.synchronize()
         graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(graph):
+        with torch.cuda.graph(graph, pool=self._cuda_graph_mem_pool):
             # compute output
             out = self.model(*args, **kwargs)
             # write out into output buffer up to out batch size
@@ -73,7 +74,7 @@ def _capture_one_graph(self, *args, **kwargs) -> torch.cuda.CUDAGraph:
         for o_buffer, o in zip(self._out_buffer_flat, out_flat):
             o_buffer[: o.shape[0]] = o
         torch.cuda.synchronize()
-
+        self._cuda_graph_mem_pool = self._cuda_graph_mem_pool or graph.pool()
         return graph

     @staticmethod
@@ -88,7 +89,7 @@ def _get_graph_batch_sizes(
         batch_sizes.update(range(multiplier, max_bs + 1, multiplier))

         # return as sorted list
-        return sorted(batch_sizes)
+        return sorted(batch_sizes, reverse=True)

     def capture_graph(self, *args, **kwargs):
         """Capture and pre-fetch the graph for variable batch size."""
@@ -118,6 +119,7 @@ def capture_graph(self, *args, **kwargs):

         # capture output once with max batch size to capture output buffers
         with CudaGraphWarmUpPhase():
+            ad_logger.info(f"Warm up with {self.max_batch_size=} before graph capture")
             out = self.model(*args, **kwargs)
         self._out_buffer_flat, out_spec = tree_flatten(out)
         assert out_spec == self._out_spec, "Output spec mismatch."
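
The two cudagraph changes work together: batch sizes are now captured from largest to smallest, and every capture reuses the memory pool of the first graph instead of allocating a fresh one. A minimal standalone sketch of the same idea (toy model and made-up batch sizes, not AutoDeploy code):

```python
import torch

# Sketch of the pooled, descending-order capture strategy from this commit.
model = torch.nn.Linear(64, 64).cuda()
static_input = torch.randn(8, 64, device="cuda")  # persistent input buffer (max batch)
out_buffer = torch.empty(8, 64, device="cuda")     # persistent output buffer (max batch)

# warm up once outside capture (mirrors CudaGraphWarmUpPhase in the commit)
model(static_input)
torch.cuda.synchronize()

graphs, mem_pool = {}, None
for bs in sorted([1, 2, 4, 8], reverse=True):       # largest batch first
    torch.cuda.synchronize()
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=mem_pool):        # reuse the pool of the first capture
        out = model(static_input[:bs])
        out_buffer[:bs] = out                       # write into the persistent buffer
    mem_pool = mem_pool or g.pool()                 # remember the pool after the first graph
    graphs[bs] = g

# replay for batch size 4: fill the static input, launch the graph, read the buffer
static_input[:4].copy_(torch.randn(4, 64, device="cuda"))
graphs[4].replay()
result = out_buffer[:4]
```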

tensorrt_llm/_torch/auto_deploy/custom_ops/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 from .linear import *
 from .mla import *
 from .quant import *
+from .rms_norm import *
 from .torch_attention import *
 from .torch_backend_attention import *
 from .torch_moe import *

tensorrt_llm/_torch/auto_deploy/custom_ops/attention_interface.py

Lines changed: 8 additions & 2 deletions
@@ -117,14 +117,20 @@ def __post_init__(self):
         # if the provided max_num_tokens is less than the max_batch_size * max_seq_len,
         # we use the provided max_num_tokens to calculate the number of pages
         total_tokens = min(self.max_num_tokens, self.max_batch_size * max_seq_len_adjusted)
-        self._num_pages = (total_tokens) // self.page_size + (total_tokens % self.page_size > 0)
+        # Num pages can not be less than max_batch_size.
+        self._num_pages = max(
+            self.max_batch_size,
+            (total_tokens) // self.page_size + (total_tokens % self.page_size > 0),
+        )
         self.input_ids = torch.ones(self.max_batch_size, 1, dtype=torch.int)
         self.position_ids = torch.zeros(self.max_batch_size, 1, dtype=torch.long)
         self.seq_len = torch.empty(self.max_batch_size, dtype=torch.int)
         self.input_pos = torch.empty_like(self.seq_len)
         self.cache_loc = torch.empty(self.num_pages, dtype=torch.int)
         self.pages_per_seq = torch.empty_like(self.seq_len)
-
+        assert self.num_pages >= self.max_batch_size, (
+            "num_pages must be greater than max_batch_size"
+        )
         # dynamic shape descriptors for tensor args
         self._dynamic_shapes: Optional[Tuple[Dict[str, Dim]]] = None
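
For illustration, plugging hypothetical numbers into the new page-count rule:

```python
# Hypothetical numbers, just to walk through the new num_pages rule.
max_batch_size, page_size = 64, 64
total_tokens = min(2048, max_batch_size * 128)   # -> 2048
pages_for_tokens = total_tokens // page_size + (total_tokens % page_size > 0)  # -> 32
num_pages = max(max_batch_size, pages_for_tokens)  # -> 64, so every sequence can hold a page
```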

tensorrt_llm/_torch/auto_deploy/custom_ops/rms_norm.py

Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
+"""Custom operator for FlashInfer and Triton RMSNorm implementation."""
+
+import flashinfer
+import torch
+
+from .triton_kernels.rms_norm import rms_norm
+
+
+@torch.library.custom_op("auto_deploy::flashinfer_rms_norm", mutates_args=())
+def flashinfer_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    """Custom operator for FlashInfer RMSNorm implementation.
+
+    Args:
+        input: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized and scaled tensor using FlashInfer implementation.
+    """
+    # Flashinfer rmsnorm expects a 2D input
+    input_flat = input.reshape(-1, input.shape[-1])
+    rmsnorm_flat = flashinfer.norm.rmsnorm(input_flat, weight, eps)
+    return rmsnorm_flat.reshape(input.shape)
+
+
+@flashinfer_rmsnorm.register_fake
+def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    """Fake implementation for the custom operator during tracing.
+
+    Args:
+        input: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Empty tensor with same shape as input.
+    """
+    return torch.empty_like(input)
+
+
+@torch.library.custom_op("auto_deploy::triton_rms_norm", mutates_args=())
+def triton_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    """Custom operator for Triton RMSNorm implementation.
+
+    Args:
+        input: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized and scaled tensor using Triton implementation.
+    """
+    return rms_norm(input, weight, eps)
+
+
+@triton_rmsnorm.register_fake
+def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    """Fake implementation for the custom operator during tracing."""
+    return torch.empty_like(input)
+
+
+@torch.library.custom_op("auto_deploy::torch_rmsnorm", mutates_args=())
+def torch_rmsnorm(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    """Custom operator for Torch RMSNorm implementation.
+
+    Args:
+        input: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+    """
+    input_dtype = input.dtype
+    input = input.to(torch.float32)
+    variance = input.pow(2).mean(-1, keepdim=True)
+    input = input * torch.rsqrt(variance + eps)
+    return weight * input.to(input_dtype)
+
+
+@torch_rmsnorm.register_fake
+def _(input: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    """Fake implementation for the custom operator during tracing."""
+    return torch.empty_like(input)
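
Once this module is imported (the `custom_ops` package re-exports it), the three variants are reachable through `torch.ops.auto_deploy.*` with the same signature. A small usage sketch, assuming a CUDA device; the tolerances are picked loosely for illustration:

```python
import torch

# Importing the package registers the custom ops with the PyTorch dispatcher.
from tensorrt_llm._torch.auto_deploy import custom_ops  # noqa: F401

x = torch.randn(8, 512, device="cuda", dtype=torch.float16)
w = torch.ones(512, device="cuda", dtype=torch.float16)
eps = 1e-6

ref = torch.ops.auto_deploy.torch_rmsnorm(x, w, eps)      # pure-PyTorch reference
fast = torch.ops.auto_deploy.triton_rms_norm(x, w, eps)   # Triton kernel, same signature
# torch.ops.auto_deploy.flashinfer_rms_norm(x, w, eps) is the third drop-in variant
# (requires the flashinfer package to be installed).

torch.testing.assert_close(ref, fast, rtol=1e-2, atol=1e-2)
```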

tensorrt_llm/_torch/auto_deploy/export/interface.py

Lines changed: 1 addition & 1 deletion
@@ -242,7 +242,7 @@ def _apply_patches(remaining_patches):
             yield from _apply_patches(remaining_patches[1:])

     # log applied patches
-    ad_logger.info(
+    ad_logger.debug(
         f"applying export patches: {', '.join([patch.get_patch_key() for patch in patches])}"
     )

tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -9,6 +9,7 @@
 from .kvcache import *
 from .quantization import *
 from .quantize_moe import *
+from .rms_norm import *
 from .rope import *
 from .sharding import *

tensorrt_llm/_torch/auto_deploy/transformations/library/kvcache.py

Lines changed: 14 additions & 7 deletions
@@ -143,8 +143,13 @@ def resize_kv_cache(

     free_mem_ratio specifies the fraction of available memory to occupy.
     """
-    free_mem, total_mem = torch.cuda.mem_get_info()
-    ad_logger.info(f"Free memory: {free_mem}, Total memory: {total_mem}")
+
+    def _get_mem_info_in_mb():
+        free_mem, total_mem = torch.cuda.mem_get_info()
+        return free_mem // 1024**2, total_mem // 1024**2
+
+    free_mem, total_mem = _get_mem_info_in_mb()
+    ad_logger.info(f"Free memory (MB): {free_mem}, Total memory (MB): {total_mem}")
     current_cache_size = cm.current_cache_size_bytes()
     current_num_pages = cm.info.num_pages
     ad_logger.info(
@@ -158,14 +163,16 @@ def resize_kv_cache(
     try:
         # Let's run a forward pass to get the memory usage
         cm.info._set_max_num_tokens_sample()
-        free_mem_pre, _ = torch.cuda.mem_get_info()
-        ad_logger.info(f"Free memory before forward pass: {free_mem_pre}")
+        free_mem_pre, _ = _get_mem_info_in_mb()
+        ad_logger.info(f"Free memory before forward pass (MB): {free_mem_pre}")
+
         egm(*cm.args)
-        free_mem_post, _ = torch.cuda.mem_get_info()
-        ad_logger.info(f"Free memory after forward pass: {free_mem_post}")
+
+        free_mem_post, _ = _get_mem_info_in_mb()
+        ad_logger.info(f"Free memory after forward pass (MB): {free_mem_post}")

         memory_for_forward_pass = free_mem_pre - free_mem_post
-        ad_logger.info(f"Memory for forward pass: {memory_for_forward_pass}")
+        ad_logger.info(f"Memory for forward pass (MB): {memory_for_forward_pass}")

         new_cache_size = free_mem_post * free_mem_ratio + current_cache_size
         new_num_pages = int(new_cache_size // (current_cache_size // current_num_pages))
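
For intuition, the resize formula with invented numbers (all values kept in one unit so the division is consistent):

```python
# Invented numbers, all in bytes, to walk through the resize arithmetic.
free_mem_post = 40 * 1024**3        # free memory left after the trial forward pass
free_mem_ratio = 0.8                # fraction of free memory the cache may claim
current_cache_size = 4 * 1024**3    # size of the initially allocated KV cache
current_num_pages = 2048

new_cache_size = free_mem_post * free_mem_ratio + current_cache_size   # 36 GiB
bytes_per_page = current_cache_size // current_num_pages               # 2 MiB per page
new_num_pages = int(new_cache_size // bytes_per_page)                  # -> 18432 pages
```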
tensorrt_llm/_torch/auto_deploy/transformations/library/rms_norm.py

Lines changed: 113 additions & 0 deletions
@@ -0,0 +1,113 @@
+"""Graph transform to optimize RMSNorm execution using FlashInfer."""
+
+from functools import partial
+
+import torch
+from torch.fx import GraphModule
+
+from ...utils.logger import ad_logger
+
+# It is important to import ADPatternMatcherPass from pattern_matcher.py, not from torch._inductor.pattern_matcher
+from ...utils.pattern_matcher import ADPatternMatcherPass, register_ad_pattern
+from .._graph import canonicalize_graph
+
+_BACKEND_OPS = {
+    "flashinfer": torch.ops.auto_deploy.flashinfer_rms_norm,
+    "triton": torch.ops.auto_deploy.triton_rms_norm,
+    "torch": torch.ops.auto_deploy.torch_rmsnorm,
+}
+
+
+def _rms_norm_pattern(data: torch.Tensor, weight: torch.Tensor, eps: float) -> torch.Tensor:
+    """Implements the RMSNorm pattern for pattern matching.
+
+    Args:
+        data: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+
+    Returns:
+        Normalized and scaled tensor.
+    """
+    input_dtype = data.dtype
+    data = data.to(torch.float32)
+    variance = data.pow(2).mean(-1, keepdim=True)
+    data = data * torch.rsqrt(variance + eps)
+    return weight * data.to(input_dtype)
+
+
+def _rms_norm_replacement(
+    data: torch.Tensor, weight: torch.Tensor, eps: float, backend: str
+) -> torch.Tensor:
+    """Backend-specific rms_norm implementation.
+
+    Args:
+        data: Input tensor to normalize.
+        weight: Scaling weights for the normalized output.
+        eps: Small constant for numerical stability.
+        backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
+
+    Returns:
+        Normalized and scaled tensor using the specified backend implementation.
+    """
+
+    assert backend.lower() in _BACKEND_OPS, (
+        f"Invalid {backend=}; must be one of {list(_BACKEND_OPS)}"
+    )
+    return _BACKEND_OPS[backend.lower()](data, weight, eps)
+
+
+def fuse_rmsnorm(gm: GraphModule, backend: str = "triton") -> None:
+    """Matches and replaces RMSNorm patterns in the graph with FlashInfer or Triton implementation.
+
+    This function sets up pattern matching to identify RMSNorm operations in the graph
+    and replaces them with optimized implementations. It uses dummy tensors to register
+    the pattern matching rules.
+
+    Args:
+        gm: Input graph module to transform.
+        backend: Backend to use for RMSNorm computation ("flashinfer" or "triton").
+
+    Returns:
+        Transformed graph module with optimized RMSNorm operations.
+    """
+    if backend.lower() not in _BACKEND_OPS:
+        raise ValueError(f"Invalid backend, must be one of {list(_BACKEND_OPS)}, got {backend}")
+    ad_logger.info(f"Starting RMSNorm pattern matching with backend: {backend}")
+
+    graph = gm.graph
+    patterns = ADPatternMatcherPass()
+
+    # Create dummy tensors for pattern matching
+    bs = 2
+    hidden_size = 512
+
+    def dummy_args(input_dtype: torch.dtype, weight_dtype: torch.dtype, eps: float = 1e-6):
+        return [
+            torch.randn(bs, hidden_size, device="cuda", dtype=input_dtype),
+            torch.randn(hidden_size, device="cuda", dtype=weight_dtype),
+            eps,
+        ]
+
+    # Define configurations for different data types
+    configs = [
+        (torch.bfloat16, torch.bfloat16),
+        (torch.float16, torch.float16),
+        (torch.float32, torch.float32),
+    ]
+
+    # Register patterns for each configuration
+    for input_dtype, weight_dtype in configs:
+        register_ad_pattern(
+            search_fn=_rms_norm_pattern,
+            replace_fn=partial(_rms_norm_replacement, backend=backend),
+            patterns=patterns,
+            dummy_args=dummy_args(input_dtype, weight_dtype),
+            op_ignore_types={},
+            scalar_workaround={"eps": 1e-6},
+        )
+
+    cnt = patterns.apply(graph)
+    ad_logger.info(f"RMSNorm pattern count: {cnt}")
+    canonicalize_graph(gm)
+    ad_logger.debug("RMSNorm pattern matching completed.")
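
AutoDeploy invokes this pass from its inference optimizer (see `transform.py` below), but it can also be driven by hand. A sketch under assumptions: a CUDA device is available, and `torch.export` is used here only as a stand-in for AutoDeploy's own graph export; whether the pattern actually matches depends on the traced graph lining up with the registered pattern:

```python
import torch
from torch import nn

from tensorrt_llm._torch.auto_deploy.transformations.library import fuse_rmsnorm


class ToyRMSNorm(nn.Module):
    """HF-style RMSNorm module, written to mirror the registered pattern."""

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.eps = eps

    def forward(self, x):
        dtype = x.dtype
        x = x.to(torch.float32)
        variance = x.pow(2).mean(-1, keepdim=True)
        x = x * torch.rsqrt(variance + self.eps)
        return self.weight * x.to(dtype)


model = ToyRMSNorm(512).to("cuda", torch.float16)
example = torch.randn(2, 8, 512, device="cuda", dtype=torch.float16)
gm = torch.export.export(model, (example,)).module()  # aten-level fx.GraphModule

fuse_rmsnorm(gm, backend="triton")  # in place: matched subgraphs now call the Triton op
print(gm.graph)                     # inspect whether auto_deploy.triton_rms_norm appears
```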

tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 5 additions & 0 deletions
@@ -22,6 +22,7 @@
     ep_shard,
     fuse_allreduce_residual_rmsnorm,
     fuse_collectives,
+    fuse_rmsnorm,
     insert_cached_attention,
     match_attention_layout,
     match_causal_attn_mask,
@@ -163,6 +164,10 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         # check if we can fuse collectives
         fuse_collectives(egm)

+        # TODO (lucaslie): add backend selection as part of configurable inference optimizers
+        # check if we can fuse rmsnorm
+        fuse_rmsnorm(egm, "flashinfer")
+
         # visualize the final graph
         if self.ad_config.visualize:
             try:

tensorrt_llm/bench/benchmark/throughput.py

Lines changed: 3 additions & 0 deletions
@@ -388,6 +388,9 @@ def throughput_command(
             logger.warning(
                 "Ignore extended_runtime_perf_knob_config for _autodeploy backend."
             )
+        kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)
+        kwargs.pop("pipeline_parallel_size", None)
+
         llm = AutoDeployLLM(**kwargs)
     else:
         llm = LLM(**kwargs)
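
The effect of the new kwargs remapping, with made-up benchmark arguments:

```python
# Made-up benchmark kwargs to show the remapping applied for the _autodeploy backend.
kwargs = {"model": "some/model", "tensor_parallel_size": 2, "pipeline_parallel_size": 1}

kwargs["world_size"] = kwargs.pop("tensor_parallel_size", None)  # AutoDeployLLM expects world_size
kwargs.pop("pipeline_parallel_size", None)                       # dropped before constructing AutoDeployLLM

assert kwargs == {"model": "some/model", "world_size": 2}
```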
