Skip to content

Commit 04b1927

Browse files
authored
[quantization] Fallback for complex models (#583)
1 parent ed91d96 commit 04b1927

File tree

2 files changed

+265
-4
lines changed

2 files changed

+265
-4
lines changed

tico/quantization/algorithm/gptq/quantizer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -110,8 +110,8 @@ def forward(layer, *args, **kwargs):
110110
):
111111
self._first_layer_ref = model.model.layers[0]
112112
else:
113-
raise RuntimeError(
114-
"GPTQ Quantizer assumes the model has a nested structure like `model.model.layers`, commonly found in LLaMA and other Hugging Face transformer models."
113+
self._first_layer_ref = (
114+
model # let's treat it as a single layer (fallback)
115115
)
116116
else:
117117
# fallback if the model is not LLaMA-like; treat whole model as single layer
@@ -180,7 +180,10 @@ def convert(self, model):
180180

181181
# Identify layers
182182
if hasattr(model, "model"):
183-
target_layers = model.model.layers
183+
if hasattr(model.model, "layers"):
184+
target_layers = model.model.layers
185+
else:
186+
target_layers = [model]
184187
else:
185188
target_layers = [model]
186189

@@ -301,7 +304,8 @@ def _hook(_, inp, out):
301304
# This line ensures we always take the first element when it's a tuple.
302305
outs = outs[0] if isinstance(outs, tuple) else outs
303306
# Update inputs for next iteration.
304-
self.cache_args[0][batch_idx] = outs
307+
if len(self.cache_args) > 0:
308+
self.cache_args[0][batch_idx] = outs
305309

306310
if torch.cuda.is_available():
307311
torch.cuda.empty_cache()
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
# Copyright (c) 2026 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
import argparse
17+
18+
import torch
19+
from transformers import AutoProcessor
20+
21+
from tico.quantization import convert, prepare
22+
23+
from tico.quantization.algorithm.gptq.utils import SensitivityCalibrator
24+
from tico.quantization.config.gptq import GPTQConfig
25+
from tico.quantization.evaluation.vlm_eval_utils import get_calib_inputs
26+
from tico.quantization.wrapq.examples.quantize_qwen3_vl_with_gptq import (
27+
evaluate_model,
28+
print_eval_results,
29+
print_markdown_comparison,
30+
)
31+
32+
# Map from the CLI --dtype choice to the torch dtype used when loading the model.
DTYPE_MAP = dict(
    float32=torch.float32,
    # TODO Support more dtypes
    # bfloat16=torch.bfloat16,
    # float16=torch.float16,
)
38+
39+
40+
def main():
    """Run the GPTQ weight-only quantization pipeline on a VLM checkpoint.

    Parses CLI arguments, loads the model and processor from Hugging Face,
    optionally evaluates the FP model, calibrates and applies GPTQ via
    tico's ``prepare``/``convert`` API, then evaluates the quantized model.
    """
    parser = argparse.ArgumentParser(
        description="GPTQ+PTQ pipeline (weight-only + activation)"
    )
    parser.add_argument(
        "--model", type=str, required=True, help="HF repo name or local path."
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to run on (cuda|cpu|mps).",
    )
    parser.add_argument(
        "--dtype",
        choices=list(DTYPE_MAP.keys()),
        default="float32",
        help="Model dtype for load.",
    )
    parser.add_argument("--seed", type=int, default=42, help="Random seed.")
    parser.add_argument(
        "--trust-remote-code",
        action="store_true",
        help="Enable only if you trust the model repo code.",
    )
    parser.add_argument(
        "--hf-token",
        type=str,
        default=None,
        help="Optional HF token for gated/private repos.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="cache_dir for using model/datasets loading",
    )
    # Fixed: defaults were string literals ("128"/"50"); argparse happened to
    # coerce them through `type=int`, but plain ints are the intended spelling.
    parser.add_argument(
        "--nsamples_for_qcalibration",
        type=int,
        default=128,  # almost standard
        help="number of samples to be used in GPTQ/PTQ calibration",
    )
    parser.add_argument(
        "--nsamples_for_evaluation",
        type=int,
        default=50,
        help="number of samples to be used in quantized model evaluation. -1 stands for the whole dataset",
    )
    parser.add_argument(
        "--calib_seq_len",
        type=int,
        default=2048,
        help=(
            "Maximum text sequence length for calibration inputs. "
            "If not set, processor default behavior is used."
        ),
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=2048,
        help=(
            "Maximum text sequence length for evaluation and export. "
            "If not set, processor default behavior is used."
        ),
    )
    parser.add_argument(
        "--linear_weight_bits",
        type=int,
        default=4,
        help="Weight bit-width for GPTQ quantization.",
    )
    parser.add_argument(
        "--gptq_mse",
        type=str,
        default=None,
        choices=["mse", "smse"],
        help="Whether and how to use mse in GPTQ.",
    )
    parser.add_argument(
        "--eval_tasks",
        type=str,
        default=None,
        help="Tasks to evaluate, e.g. `vqav2,textvqa`.",
    )
    parser.add_argument(
        "--sensitivity_path",
        type=str,
        default=None,
        help="Optional path to precomputed sensitivity info (used with --gptq_mse smse).",
    )

    args = parser.parse_args()
    print(args)

    # Basic setup
    torch.manual_seed(args.seed)
    device = torch.device(args.device)
    dtype = DTYPE_MAP[args.dtype]

    print("=== Config ===")
    print(f"Model : {args.model}")
    print(f"Device : {device.type}")
    print(f"DType : {args.dtype}")
    print(f"Calib seq len : {args.calib_seq_len}")
    print(f"Max seq len : {args.max_seq_len}")
    print()

    # -------------------------------------------------------------------------
    # Load model and processor
    # -------------------------------------------------------------------------
    print("Loading FP model …")

    processor = AutoProcessor.from_pretrained(
        args.model, trust_remote_code=True, cache_dir=args.cache_dir
    )
    # "balanced" spreads layers across available accelerators; plain CPU otherwise.
    # NOTE(review): --trust-remote-code is parsed but loading hardcodes
    # trust_remote_code=True; kept as-is to avoid breaking existing usage.
    dev_map = "balanced" if args.device != "cpu" else "cpu"
    try:
        from transformers import AutoModelForImageTextToText

        model = AutoModelForImageTextToText.from_pretrained(
            args.model,
            dtype=dtype,
            trust_remote_code=True,
            cache_dir=args.cache_dir,
            device_map=dev_map,
            token=args.hf_token,  # was parsed but never forwarded
        )
    # Was a bare `except:` — that also swallows KeyboardInterrupt/SystemExit.
    except Exception:
        # Fallback for transformers versions that expose VLMs via
        # AutoModelForVision2Seq and spell the dtype kwarg `torch_dtype`.
        from transformers import AutoModelForVision2Seq

        model = AutoModelForVision2Seq.from_pretrained(
            args.model,
            torch_dtype=dtype,
            trust_remote_code=True,
            cache_dir=args.cache_dir,
            device_map=dev_map,
            token=args.hf_token,  # was parsed but never forwarded
        )

    model.eval()
    # Disable KV-cache so calibration forwards see the full sequence each time.
    if hasattr(model, "config") and hasattr(model.config, "use_cache"):
        model.config.use_cache = False
    if hasattr(model, "config") and hasattr(model.config, "text_config"):
        if hasattr(model.config.text_config, "use_cache"):
            model.config.text_config.use_cache = False

    # Clamp positional range to the calibration length. Guarded: not every
    # config exposes text_config/max_position_embeddings (the original
    # dereferenced both unconditionally and could raise AttributeError).
    if args.calib_seq_len is not None:
        text_cfg = getattr(getattr(model, "config", None), "text_config", None)
        if text_cfg is not None and hasattr(text_cfg, "max_position_embeddings"):
            text_cfg.max_position_embeddings = min(
                text_cfg.max_position_embeddings, args.calib_seq_len
            )

    if args.eval_tasks is not None:
        print("Evaluating original model")
        original_results = evaluate_model(
            model,
            processor,
            args.eval_tasks,
            args.device,
            args.nsamples_for_evaluation,
            max_seq_len=args.max_seq_len,
        )
        print_eval_results("Evaluating original model", original_results)
        for key in original_results:
            result = original_results[key]
            print(
                f"Original EM: {result[0]/result[1]:.4f} (dataset={key}, n={result[1]})"
            )

    calib_inputs = get_calib_inputs(
        "vqav2", processor, n_samples=args.nsamples_for_qcalibration
    )

    # -------------------------------------------------------------------------
    # Run GPTQ (weight-only) pass
    # -------------------------------------------------------------------------
    print("Applying GPTQ …")

    sens = None
    if args.gptq_mse == "smse":  # `== "smse"` already implies `is not None`
        if args.sensitivity_path is not None:
            # NOTE: torch.load deserializes via pickle — only load sensitivity
            # files from trusted sources.
            sens = torch.load(args.sensitivity_path)
        else:
            calibrator = SensitivityCalibrator(model, calib_inputs)
            sens = calibrator.compute_sensitivity_info()

    gptq_config = GPTQConfig(
        weight_bits=args.linear_weight_bits,
        perchannel=True,
        mse=args.gptq_mse,
        sensitivity=sens,
    )
    q_m = prepare(model, gptq_config, inplace=True)

    # Calibration forwards: GPTQ hooks observe activations during these calls.
    with torch.no_grad():
        for inp in calib_inputs:
            for item in inp:
                inp[item] = inp[item].to(args.device)
            q_m(**inp)

    q_m = convert(q_m, inplace=True)

    # -------------------------------------------------------------------------
    # evaluate quantized model
    # -------------------------------------------------------------------------
    if args.eval_tasks is not None:
        quantized_results = evaluate_model(
            q_m,
            processor,
            args.eval_tasks,
            args.device,
            args.nsamples_for_evaluation,
            max_seq_len=args.max_seq_len,
        )
        print_eval_results("Evaluating quantized model", quantized_results)
        print_markdown_comparison(original_results, quantized_results)


if __name__ == "__main__":
    main()

0 commit comments

Comments (0)