
Commit b47abb4

Authored by h-guo18 and Fridah-nv
Move quantization & quant_moe to new inf optimizer (#112)
* refactor: move quantization and quant_moe to new inf optimizer (Signed-off-by: haoguo <[email protected]>)
* refactor: use quant_config from factory instead of new config type (Signed-off-by: haoguo <[email protected]>)
* refactor: del old files; update default.yaml (Signed-off-by: haoguo <[email protected]>)
* move helper class FakeFactory to _graph_test_helpers.py (Signed-off-by: haoguo <[email protected]>)
* polish: remove unreachable branch in quantization.py (Co-authored-by: Fridah-nv <[email protected]>; Signed-off-by: h-guo18 <[email protected]>)
* style: run pre-commit (Signed-off-by: haoguo <[email protected]>)
* fix to fetch hf_quant_config from fetched dir (Signed-off-by: Frida Hou <[email protected]>)

Signed-off-by: haoguo <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: Frida Hou <[email protected]>
Co-authored-by: Fridah-nv <[email protected]>
1 parent 7810dd0 · commit b47abb4

File tree: 9 files changed (+207 −119 lines)

tensorrt_llm/_torch/auto_deploy/config/default.yaml

Lines changed: 4 additions & 0 deletions
@@ -19,3 +19,7 @@ transforms:
     stage: post_export
   cleanup_input_constraints:
     stage: post_export
+  quantize:
+    stage: pattern_matcher
+  quantize_moe:
+    stage: pattern_matcher
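These two entries are what lets the config-driven optimizer pick up the moved transforms by name. As a rough illustration only (the real InferenceOptimizer and registry APIs may differ; the runner below is a hypothetical sketch, not code from this PR), a stage runner driven by this YAML could look like:

# Hypothetical sketch of config-driven dispatch for the "pattern_matcher" stage
# entries added above. Registry lookup and call convention are assumptions.
from typing import Any, Dict

def run_stage(stage: str, transforms_cfg: Dict[str, Dict[str, Any]],
              registry: Dict[str, Any], gm, cm, factory):
    for name, cfg in transforms_cfg.items():
        if cfg.get("stage") != stage:
            continue
        transform = registry[name]()  # e.g. Quantization / QuantizeMOE from this PR
        gm, info = transform._apply(gm, cm, factory)  # real code likely calls a public entry point
        print(f"{name}: skipped={info.skipped}, matches={info.num_matches}")
    return gm

# Usage mirroring default.yaml:
# run_stage("pattern_matcher",
#           {"quantize": {"stage": "pattern_matcher"},
#            "quantize_moe": {"stage": "pattern_matcher"}},
#           registry, gm, cm, factory)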

tensorrt_llm/_torch/auto_deploy/models/hf.py

Lines changed: 3 additions & 3 deletions
@@ -305,7 +305,7 @@ def _prefetch_checkpoint(self, model_name_or_path: str, skip_prefetch_weights: b
         # at this point it should be a directory (either the original one or the download dir)
         assert os.path.isdir(fetched_dir), f"Checkpoint path {fetched_dir} is not a directory."

-        self._load_quantization_config()
+        self._load_quantization_config(fetched_dir)

         return fetched_dir

@@ -323,13 +323,13 @@ def _load_checkpoint(self, model: nn.Module, device: DeviceLikeType):
         # model-transformed weights,leading to unexpected key mismatches or format issues.
         load_checkpoint_in_model(model, checkpoint=ckpt_file, full_state_dict=False)

-    def _load_quantization_config(self):
+    def _load_quantization_config(self, fetched_dir: str):
         """Load the quantization config from the model directory if not done already."""
         if self._quant_config is not None:
             return

         assert self.model
-        hf_quant_config_file = os.path.join(self.model, "hf_quant_config.json")
+        hf_quant_config_file = os.path.join(fetched_dir, "hf_quant_config.json")
         if os.path.exists(hf_quant_config_file):
             with open(hf_quant_config_file, "r") as file:
                 quantization_config = json.load(file)
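The fix reads hf_quant_config.json from the prefetched checkpoint directory instead of from self.model, which may still be a hub model ID rather than a local path. A minimal standalone sketch of the corrected lookup (paraphrasing the diff above, not the full HF factory code):

# Minimal sketch, assuming a local fetched_dir that may contain hf_quant_config.json.
import json
import os

def load_quant_config(fetched_dir: str):
    cfg_file = os.path.join(fetched_dir, "hf_quant_config.json")
    if not os.path.exists(cfg_file):
        return None  # no quantization config shipped with the checkpoint
    with open(cfg_file, "r") as f:
        return json.load(f)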

tensorrt_llm/_torch/auto_deploy/transformations/library/quantization.py renamed to tensorrt_llm/_torch/auto_deploy/transform/library/quantization.py

Lines changed: 64 additions & 50 deletions
@@ -1,11 +1,12 @@
 from collections import defaultdict
 from functools import partial
-from typing import Any, Dict
+from typing import Dict, Tuple

 import torch.nn as nn
 from torch.fx import GraphModule, Node

-from ...utils.logger import ad_logger
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import (
     extract_param_names_from_lin_node,
     get_quantization_params_from_linear_node,

@@ -20,7 +21,7 @@
     remove_output_quantizers,
     should_skip_quantization,
 )
-from .._graph import canonicalize_graph
+from ..interface import BaseTransform, TransformInfo, TransformRegistry


 def _insert_quantized_linear(

@@ -138,12 +139,8 @@ def get_scale_name(scale_name):
         scale_target_module = gm  # Register in root module
         scale_name_prefix = ""

-        ad_logger.info(f"Quantized BMM with dynamic weight tensor for node {node}")
     else:
         # If we can't determine the shape, skip quantization
-        ad_logger.warning(
-            f"BMM weight is dynamic tensor without shape metadata, skipping quantization for node {node}"
-        )
         return

     # Common logic for both parameter and dynamic tensor cases

@@ -169,53 +166,70 @@ def get_scale_name(scale_name):
     node.args = (*node.args, *scale_values)


-def quantize(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
-    """Quantize the GraphModule and replace linear with quantized linear."""
-    # extract info from quant_config
-    is_quant_graph = is_quantized_graph(gm)
-    quant_algo = quant_config.get("quant_algo")
-    excluded_patterns = quant_config.get("exclude_modules", [])
-
-    # no quantization to do
-    if not (is_quant_graph or quant_config):
-        ad_logger.info("No quantization to do.")
-        return
-
-    # tracking quantized operations in the graph
-    quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
-    for n in gm.graph.nodes:
-        if should_skip_quantization(n, excluded_patterns):
-            continue
-
-        # Process linear operations
-        if is_linear_op(n, include_quantization=False):
-            # get per-layer quantization format from the node
-            quant_algo_n: str = (
-                get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
-            )
-            if not quant_algo_n:
-                continue
-
-            # insert quantized linear node
-            _insert_quantized_linear(gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph)
-            quantized_nodes[quant_algo_n]["linear"] += 1
-
-        # Process BMM operations
-        elif is_bmm_op(n):
-            if not quant_algo:
-                continue
-
-            # insert quantized bmm node
-            _insert_quantized_bmm(
-                gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
-            )
-            quantized_nodes[quant_algo]["bmm"] += 1
-
-    if is_quant_graph:
-        remove_output_quantizers(gm)
-
-    canonicalize_graph(gm)
-    for quant_algo in quantized_nodes:
-        for op_type, count in quantized_nodes[quant_algo].items():
-            ad_logger.info(f"Found {count} {quant_algo} quantized {op_type} nodes.")
-    ad_logger.debug("After quantization: " + str(gm))
+@TransformRegistry.register("quantize")
+class Quantization(BaseTransform):
+    """Quantize the GraphModule and replace linear/BMM with quantized linear/BMM."""
+
+    def _apply(
+        self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
+    ) -> Tuple[GraphModule, TransformInfo]:
+        # extract info from quant_config
+        quant_config = factory.get_quant_config()
+        if not quant_config:
+            return gm, TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+
+        is_quant_graph = is_quantized_graph(gm)
+        quant_algo = quant_config.get("quant_algo")
+        excluded_patterns = quant_config.get("exclude_modules", [])
+        if not quant_algo:
+            return gm, TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+
+        # tracking quantized operations in the graph
+        quantized_nodes: Dict[str, Dict[str, int]] = defaultdict(lambda: defaultdict(int))
+        for n in gm.graph.nodes:
+            if should_skip_quantization(n, excluded_patterns):
+                continue
+
+            # Process linear operations
+            if is_linear_op(n, include_quantization=False):
+                # get per-layer quantization format from the node
+                quant_algo_n: str = (
+                    get_quantization_from_linear_node(n) if is_quant_graph else quant_algo
+                )
+                if not quant_algo_n:
+                    continue
+
+                # insert quantized linear node
+                _insert_quantized_linear(
+                    gm, n, QuantizationImpl.create(quant_algo_n), is_quant_graph
+                )
+                quantized_nodes[quant_algo_n]["linear"] += 1
+
+            # Process BMM operations
+            elif is_bmm_op(n):
+                if not quant_algo:
+                    continue
+
+                # insert quantized bmm node
+                _insert_quantized_bmm(
+                    gm, n, QuantizationImpl.create(quant_algo, is_bmm=True), is_quant_graph
+                )
+                quantized_nodes[quant_algo]["bmm"] += 1
+
+        if is_quant_graph:
+            remove_output_quantizers(gm)
+
+        num_matches = 0
+        for quant_algo in quantized_nodes:
+            for op_type, count in quantized_nodes[quant_algo].items():
+                num_matches += count
+
+        info = TransformInfo(
+            skipped=False, num_matches=num_matches, is_clean=False, has_valid_shapes=True
+        )
+
+        return gm, info
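The free function is replaced by a class registered under the name "quantize"; logging and canonicalize_graph calls move out of the transform, which now reports its work through TransformInfo. For readers unfamiliar with this pattern, here is a stripped-down, hypothetical rendering of the registry machinery (the real BaseTransform, TransformInfo, and TransformRegistry live in ..interface and carry more state; this only illustrates the shape of the API the diff relies on):

# Hypothetical simplification of the registry pattern used above; not the actual
# tensorrt_llm implementation.
from dataclasses import dataclass
from typing import Callable, Dict

@dataclass
class TransformInfo:
    skipped: bool            # transform decided there was nothing to do
    num_matches: int         # how many graph nodes were rewritten
    is_clean: bool           # whether the graph still needs canonicalization
    has_valid_shapes: bool   # whether shape metadata is still trustworthy

class TransformRegistry:
    _registry: Dict[str, Callable] = {}

    @classmethod
    def register(cls, name: str):
        # used as @TransformRegistry.register("quantize") on a transform class
        def wrap(transform_cls):
            cls._registry[name] = transform_cls
            return transform_cls
        return wrap

    @classmethod
    def get(cls, name: str):
        return cls._registry[name]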

tensorrt_llm/_torch/auto_deploy/transformations/library/quantize_moe.py renamed to tensorrt_llm/_torch/auto_deploy/transform/library/quantize_moe.py

Lines changed: 52 additions & 40 deletions
@@ -1,14 +1,15 @@
 from functools import partial
-from typing import Any, Callable, Dict, List, Tuple
+from typing import Callable, List, Tuple

 import torch
 import torch.nn as nn
 from torch.fx import GraphModule, Node

-from ...utils.logger import ad_logger
+from ...models.factory import ModelFactory
+from ...shim.interface import CachedSequenceInterface
 from ...utils.node_utils import is_op
 from ...utils.quantization_utils import QuantizationImpl, should_skip_quantization
-from .._graph import canonicalize_graph
+from ..interface import BaseTransform, TransformInfo, TransformRegistry

 quantized_moe_op_map = {
     "FP8": torch.ops.auto_deploy.torch_quant_fp8_moe,

@@ -92,47 +93,10 @@ def collect_scales(index: int) -> Tuple[List[Node], List[Node], List[Node]]:
         quantized_op,
         args=tuple(args),
     )
-    ad_logger.debug(f"Updating {node.name} args to {new_node.args}")
     node.replace_all_uses_with(new_node)
     gm.graph.erase_node(node)


-def quantize_moe(gm: GraphModule, quant_config: Dict[str, Any]) -> None:
-    """
-    Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the
-    quantized version using the quant_algo from quant_config.
-    """
-    quant_algo = quant_config.get("quant_algo")
-    if not quant_algo:
-        ad_logger.info("No quantization to do.")
-        return gm
-    excluded_patterns = quant_config.get("exclude_modules", [])
-
-    quant_impl = QuantizationImpl.create(quant_algo)
-    quantized_op = quantized_moe_op_map[quant_algo]
-
-    count = 0
-
-    for node in list(gm.graph.nodes):
-        if is_op(node, torch.ops.auto_deploy.torch_moe):
-            # Check that all expert weights should be quantized
-            w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
-            if any(
-                should_skip_quantization(n, excluded_patterns)
-                for n in w1_names + w2_names + w3_names
-            ):
-                continue
-            _quantize_moe_node(gm, node, quant_impl, quantized_op)
-            count += 1
-
-    if count == 0:
-        return gm
-
-    gm = canonicalize_graph(gm)
-    ad_logger.info(f"Found {count} {quant_algo} quantized {quantized_op} nodes.")
-    return
-
-
 # TODO(Fridah-nv): robust handling similar to `extract_param_names_from_lin_node` or expand it
 def _extract_moe_weight_param_lists(moe_node: Node) -> Tuple[List[str], List[str], List[str]]:
     """

@@ -165,3 +129,51 @@ def _unwrap_list(arg) -> List[str]:
     w3_names = _unwrap_list(w3_list)

     return w1_names, w2_names, w3_names
+
+
+@TransformRegistry.register("quantize_moe")
+class QuantizeMOE(BaseTransform):
+    """
+    Traverse gm, find every torch.ops.auto_deploy.torch_moe, and replace it with the
+    quantized version using the quant_algo from quant_config.
+    """
+
+    def _apply(
+        self, gm: GraphModule, cm: CachedSequenceInterface, factory: ModelFactory
+    ) -> Tuple[GraphModule, TransformInfo]:
+        quant_config = factory.get_quant_config()
+        quant_algo = quant_config.get("quant_algo") if quant_config else None
+
+        if not quant_config or not quant_algo:
+            return gm, TransformInfo(
+                skipped=True, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+        excluded_patterns = quant_config.get("exclude_modules", [])
+
+        quant_impl = QuantizationImpl.create(quant_algo)
+        quantized_op = quantized_moe_op_map[quant_algo]
+
+        count = 0
+
+        for node in list(gm.graph.nodes):
+            if is_op(node, torch.ops.auto_deploy.torch_moe):
+                # Check that all expert weights should be quantized
+                w1_names, w2_names, w3_names = _extract_moe_weight_param_lists(node)
+                if any(
+                    should_skip_quantization(n, excluded_patterns)
+                    for n in w1_names + w2_names + w3_names
+                ):
+                    continue
+                _quantize_moe_node(gm, node, quant_impl, quantized_op)
+                count += 1
+
+        if count == 0:
+            return gm, TransformInfo(
+                skipped=False, num_matches=0, is_clean=True, has_valid_shapes=True
+            )
+
+        info = TransformInfo(
+            skipped=False, num_matches=count, is_clean=False, has_valid_shapes=False
+        )
+
+        return gm, info
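Note that QuantizeMOE now distinguishes three outcomes instead of logging and returning: skipped (no quant config), ran but matched nothing, or rewrote the graph. A short illustrative helper (not library code) summarizing what each TransformInfo returned above means downstream:

# Illustrative only: interprets the three TransformInfo return paths of QuantizeMOE.
def summarize(info) -> str:
    if info.skipped:
        return "no quant_config/quant_algo: transform skipped, graph untouched"
    if info.num_matches == 0:
        return "ran, but no torch_moe node qualified: graph still clean"
    return (f"replaced {info.num_matches} MoE node(s); graph marked dirty "
            f"(is_clean=False) and shapes need re-validation (has_valid_shapes=False)")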

tensorrt_llm/_torch/auto_deploy/transformations/library/__init__.py

Lines changed: 0 additions & 2 deletions
@@ -6,8 +6,6 @@
 from .fused_moe import *
 from .fusion import *
 from .kvcache import *
-from .quantization import *
-from .quantize_moe import *
 from .rms_norm import *
 from .rope import *
 from .sharding import *

tensorrt_llm/_torch/auto_deploy/transformations/transform.py

Lines changed: 0 additions & 5 deletions
@@ -32,8 +32,6 @@
     match_rope_layout,
     match_rope_pattern,
     optimize_rope,
-    quantize,
-    quantize_moe,
     resize_kv_cache,
     sharding_transform_executor,
     update_in_out_nodes,

@@ -70,9 +68,6 @@ def __call__(self, cm: CachedSequenceInterface) -> nn.Module:
         ############################################################################################
         # RUN PATTERN MATCHER TRANSFORMATIONS TO STANDARDIZE GRAPH REPRESENTATION
         ############################################################################################
-        # quantization
-        quantize(egm, self.factory.get_quant_config())
-        quantize_moe(egm, self.factory.get_quant_config())

         # Match MoE pattern
         match_moe_pattern(egm)

tests/unittest/_torch/auto_deploy/_utils_test/_graph_test_helpers.py

Lines changed: 22 additions & 5 deletions
@@ -10,15 +10,32 @@

 from tensorrt_llm._torch.auto_deploy.custom_ops.attention_interface import SequenceInfo
 from tensorrt_llm._torch.auto_deploy.export import torch_export_to_gm
+from tensorrt_llm._torch.auto_deploy.models.factory import ModelFactory
 from tensorrt_llm._torch.auto_deploy.transformations.library.sharding import ShardingTransformInfo


-class FakeFactory:
-    def __init__(self, model: nn.Module):
-        self.model = model
+class FakeFactory(ModelFactory):
+    """Dummy factory to pass cache_config for testing."""

-    def build_model(self, device: str) -> nn.Module:
-        return self.model.to(device=device)
+    def __init__(self, model=None, cache_config=None, quant_config=None):
+        self._model = model
+        self.cache_config = cache_config
+        self.quant_config = quant_config
+
+    def build_model(self, device: str):
+        return self._model.to(device=device) if self._model else None
+
+    def _build_model(self, device: str):
+        return
+
+    def _load_checkpoint(self, model, device):
+        return
+
+    def get_cache_config(self):
+        return self.cache_config
+
+    def get_quant_config(self):
+        return self.quant_config


 class SequenceEmbeddingInfo(SequenceInfo):
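With the helper now subclassing ModelFactory and accepting a quant_config, tests can feed quantization settings into the new factory-driven transforms. A hypothetical test snippet (the import path assumes the _utils_test directory is on sys.path, as the unit tests arrange; the config values are examples):

# Example usage of the extended FakeFactory in a unit test.
import torch.nn as nn
from _graph_test_helpers import FakeFactory  # assumes _utils_test is on sys.path

factory = FakeFactory(
    model=nn.Linear(16, 16),
    quant_config={"quant_algo": "FP8", "exclude_modules": []},
)
assert factory.get_quant_config()["quant_algo"] == "FP8"
assert factory.build_model("cpu") is not None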
