Skip to content

Commit 5201525

Browse files
committed
[quantization] Full quantization
This draft tries to get a fully quantized model. TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent e49ffdb commit 5201525

21 files changed

+855
-45
lines changed

tico/passes/decompose_fake_quantize.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,27 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
123123
)
124124
node.replace_all_uses_with(dequnt, propagate_meta=True)
125125
modified = True
126+
127+
if node.target in [torch.ops.circle_custom.quantize_mx.default]:
128+
# tensor, elem_format, axis
129+
assert len(node.args) == 3
130+
_, elem_format, axis = node.args
131+
132+
with gm.graph.inserting_before(node):
133+
quant = create_node(
134+
g,
135+
torch.ops.circle_custom.quantize_float_to_mx.default,
136+
args=node.args,
137+
origin=node,
138+
)
139+
dequnt = create_node(
140+
g,
141+
torch.ops.circle_custom.dequantize_mx_to_float.default,
142+
args=(quant, *quant.args[1:]),
143+
kwargs=quant.kwargs,
144+
)
145+
node.replace_all_uses_with(dequnt, propagate_meta=True)
146+
modified = True
126147

127148
gm.graph.eliminate_dead_code()
128149
gm.graph.lint()

tico/quantization/algorithm/gptq/quantizer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def convert(self, model):
212212
for name in subset:
213213
gptq[name] = GPTQ(subset[name])
214214
gptq[name].quantizer.configure(
215-
bits=8, perchannel=True, sym=False, mse=False
215+
bits=4, perchannel=True, sym=False, mse=False
216216
)
217217

218218
# Hook to collect (inp, out) for GPTQ

tico/quantization/passes/fold_quant_ops.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,78 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
144144
)
145145
dq.replace_all_uses_with(op, propagate_meta=False)
146146
logger.debug(f"Removed redundant {dq.name}")
147+
148+
for dq in graph.nodes:
149+
if dq.op != "call_function":
150+
continue
151+
if (
152+
dq.target
153+
!= torch.ops.circle_custom.dequantize_mx_to_float.default
154+
):
155+
continue
156+
157+
dq_args = dq.args
158+
159+
q = dq_args[0]
160+
if q.target != torch.ops.circle_custom.quantize_float_to_mx.default:
161+
continue
162+
q_args = q.args
163+
op = q_args[0]
164+
165+
# Check if Q and DQ have same parameters
166+
if q_args[1] != dq_args[1]:
167+
continue
168+
if q_args[2] != dq_args[2]:
169+
continue
170+
171+
# ───────────────────────────────────────────
172+
# Case 1: op not yet quantized
173+
# ───────────────────────────────────────────
174+
if QPARAM_KEY not in op.meta:
175+
#TODO
176+
qparam = QuantParam()
177+
qparam.dtype = "mxint8"# q_args[1] #TODO
178+
qparam.quantized_dimension = q_args[2]
179+
op.meta[QPARAM_KEY] = qparam
180+
181+
dq.replace_all_uses_with(op, propagate_meta=False)
182+
183+
logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
184+
# ───────────────────────────────────────────
185+
# Case 2: op already quantized
186+
# 2.1 same dtype → nothing to do
187+
# 2.2 diff dtype → leave Q in place
188+
# ───────────────────────────────────────────
189+
else:
190+
op_qparam: QuantParam = op.meta[QPARAM_KEY]
191+
qdq_dtype = "mxint8"#q_args[1] #TODO
192+
193+
if op_qparam.dtype != qdq_dtype:
194+
# Attach QPARAM to Q once
195+
if QPARAM_KEY not in q.meta:
196+
qparam = QuantParam()
197+
qparam.dtype = qdq_dtype
198+
qparam.quantized_dimension = q_args[2]
199+
q.meta[QPARAM_KEY] = qparam
200+
assert len(q.users) == 1, "Fix me unless"
147201

202+
dq.replace_all_uses_with(q, propagate_meta=False)
203+
logger.debug(f"{dq.name} is folded ({q.name} is left).")
204+
else:
205+
# Same dtype → the Quantize–Dequantize pair is redundant.
206+
assert not op_qparam.scale
207+
assert not op_qparam.zero_point
208+
assert (
209+
op_qparam.dtype
210+
and op_qparam.dtype == 'mxint8' #TODO
211+
)
212+
assert (
213+
op_qparam.quantized_dimension is not None
214+
and op_qparam.quantized_dimension == q_args[2]
215+
)
216+
dq.replace_all_uses_with(op, propagate_meta=False)
217+
logger.debug(f"Removed redundant {dq.name}")
218+
148219
graph.eliminate_dead_code()
149220
graph.lint()
150221
graph_module.recompile()

tico/quantization/wrapq/examples/quantize_linear.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
from tico.quantization.wrapq.mode import Mode
3737
from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
3838
from tico.utils.utils import SuppressWarning
39+
from tico.quantization.wrapq.dtypes import DType
40+
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
41+
from tico.quantization.wrapq.observers.mx import MXObserver
42+
from tico.quantization.wrapq.qscheme import QScheme
3943

4044

4145
# -------------------------------------------------------------------------
@@ -62,7 +66,14 @@ def forward(self, x):
6266
# -------------------------------------------------------------------------
6367
# 1. Replace the Linear with QuantLinear wrapper
6468
# -------------------------------------------------------------------------
65-
model.fc = prepare(fp32_layer, PTQConfig()) # type: ignore[assignment]
69+
cfg = PTQConfig(
70+
default_dtype=DType.uint(8),
71+
default_qscheme=QScheme.PER_TENSOR_ASYMM,
72+
default_observer=MXObserver,#MinMaxObserver,
73+
overrides = {"weight": {"dtype": DType.uint(4), "observer": MinMaxObserver}},
74+
)
75+
76+
model.fc = prepare(fp32_layer,cfg) # type: ignore[assignment]
6677
qlayer = model.fc # alias for brevity
6778

6879
# -------------------------------------------------------------------------

tico/quantization/wrapq/examples/quantize_llama_attn.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
from tico.quantization.wrapq.mode import Mode
2525
from tico.quantization.wrapq.wrappers.llama.quant_attn import QuantLlamaAttention
2626
from tico.utils.utils import SuppressWarning
27+
from tico.quantization.wrapq.dtypes import DType
28+
from tico.quantization.wrapq.mode import Mode
29+
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
30+
from tico.quantization.wrapq.observers.mx import MXObserver
31+
from tico.quantization.wrapq.qscheme import QScheme
2732

2833
name = "Maykeye/TinyLLama-v0"
2934
model = AutoModelForCausalLM.from_pretrained(name)
@@ -33,7 +38,23 @@
3338
# 1. Replace layer-0’s MLP with QuantLlamaMLP
3439
# -------------------------------------------------------------------------
3540
orig_attn = model.model.layers[0].self_attn
36-
model.model.layers[0].self_attn = prepare(orig_attn, PTQConfig())
41+
cfg = PTQConfig(
42+
default_dtype=DType.int(16),#DType.uint(8),
43+
default_qscheme=QScheme.PER_TENSOR_SYMM,#QScheme.PER_TENSOR_ASYMM,
44+
default_observer=MXObserver,#MinMaxObserver,
45+
overrides={
46+
"q_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver},
47+
#"act_out": {"dtype": DType.int(16), "observer":MinMaxObserver}
48+
},
49+
"k_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
50+
"v_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
51+
"o_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
52+
"scale": {"observer":MinMaxObserver},
53+
#"softmax": {"observer":MinMaxObserver},
54+
},
55+
)
56+
57+
model.model.layers[0].self_attn = prepare(orig_attn, cfg)
3758
model.eval()
3859

3960
attn_q = model.model.layers[0].self_attn # quant wrapper

tico/quantization/wrapq/examples/quantize_llama_mlp.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
from tico.quantization.wrapq.qscheme import QScheme
2828
from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
2929
from tico.utils.utils import SuppressWarning
30+
from tico.quantization.wrapq.dtypes import DType
31+
from tico.quantization.wrapq.mode import Mode
32+
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
33+
from tico.quantization.wrapq.observers.mx import MXObserver
34+
from tico.quantization.wrapq.qscheme import QScheme
3035

3136
name = "Maykeye/TinyLLama-v0"
3237
model = AutoModelForCausalLM.from_pretrained(name)
@@ -37,8 +42,19 @@
3742
# 1. Replace layer-0’s MLP with QuantLlamaMLP
3843
# -------------------------------------------------------------------------
3944
fp32_mlp = model.model.layers[0].mlp
45+
cfg = PTQConfig(
46+
default_dtype=DType.int(16),
47+
default_qscheme=QScheme.PER_TENSOR_SYMM,
48+
default_observer=MXObserver, #MinMaxObserver,
49+
overrides={
50+
"gate_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
51+
"up_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
52+
"down_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
53+
}
54+
)
55+
4056
model.model.layers[0].mlp = prepare(
41-
fp32_mlp, PTQConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM)
57+
fp32_mlp, cfg#PTQConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM)
4258
)
4359
model.eval()
4460

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,169 @@
1+
# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# =============================================================================
16+
# POST-TRAINING QUANTIZATION EXAMPLE — Llama Decoder Layer (Self-Attn + MLP)
17+
# -----------------------------------------------------------------------------
18+
# This demo shows how to:
19+
# 1. Replace a single FP32 `LlamaDecoderLayer` with `QuantLlamaDecoderLayer`.
20+
# 2. Collect activation statistics in one calibration sweep.
21+
# 3. Freeze scales / zero-points and switch to INT-simulation mode.
22+
# 4. Compare INT-8 vs FP32 outputs with a quick mean-absolute-diff check.
23+
# 5. Export the calibrated, quantized block to a Circle model.
24+
# -----------------------------------------------------------------------------
25+
# Style / layout is kept identical to the `quantize_llama_attn.py` and
26+
# `quantize_llama_mlp.py` examples for easy side-by-side reading.
27+
# =============================================================================
28+
29+
import pathlib
30+
31+
import torch
32+
from transformers import AutoModelForCausalLM, AutoTokenizer
33+
34+
from tico.quantization import convert, prepare
35+
from tico.quantization.config.ptq import PTQConfig
36+
from tico.quantization.evaluation.metric import compute_peir
37+
from tico.quantization.evaluation.utils import plot_two_outputs
38+
from tico.quantization.wrapq.dtypes import DType
39+
from tico.quantization.wrapq.mode import Mode
40+
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
41+
from tico.quantization.wrapq.observers.mx import MXObserver
42+
from tico.quantization.wrapq.qscheme import QScheme
43+
from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
44+
QuantLlamaDecoderLayer,
45+
)
46+
from tico.utils.utils import SuppressWarning
47+
48+
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct" #"Maykeye/TinyLLama-v0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
49+
model = AutoModelForCausalLM.from_pretrained(
50+
MODEL_NAME, cache_dir="/mnt/storage/transformers_cache"
51+
)
52+
tokenizer = AutoTokenizer.from_pretrained(
53+
MODEL_NAME, cache_dir="/mnt/storage/transformers_cache"
54+
)
55+
model.config.max_position_embeddings = 2048 # we need this to prevent RAM exhaust
56+
model.config.use_cache = False
57+
58+
model.eval() # disable dropout, etc.
59+
rotary = model.model.rotary_emb # RoPE helper
60+
61+
# -------------------------------------------------------------------------
62+
# 1. Swap in the quant wrapper
63+
# -------------------------------------------------------------------------
64+
fp32_layer = model.model.layers[0] # keep a reference for diff check
65+
cfg = PTQConfig(
66+
default_dtype=DType.int(16),
67+
default_qscheme=QScheme.PER_TENSOR_SYMM,
68+
default_observer=MXObserver,#MinMaxObserver,
69+
overrides={
70+
"mlp": {
71+
"gate_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
72+
"up_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
73+
"down_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
74+
},
75+
"self_attn": {
76+
"q_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
77+
"k_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
78+
"v_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
79+
"o_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
80+
"scale": {"observer":MinMaxObserver},
81+
#"softmax": {"observer":MinMaxObserver},
82+
},
83+
"input_layernorm" : {"weight": {"dtype": DType.int(16), "observer":MinMaxObserver},
84+
#"act_in":{"observer":MinMaxObserver},
85+
# "act_out":{"observer":MinMaxObserver}
86+
},
87+
"post_attention_layernorm" : {"weight": {"dtype": DType.int(16), "observer":MinMaxObserver},
88+
# "act_in":{"observer":MinMaxObserver},
89+
# "act_out":{"observer":MinMaxObserver}
90+
},
91+
},
92+
)
93+
94+
model.model.layers[0] = prepare(fp32_layer, cfg)
95+
model.eval()
96+
97+
qlayer = model.model.layers[0] # alias for brevity
98+
assert isinstance(qlayer.wrapped, QuantLlamaDecoderLayer)
99+
100+
# -------------------------------------------------------------------------
101+
# 2. Single-pass calibration (gather activation ranges)
102+
# -------------------------------------------------------------------------
103+
PROMPTS = [
104+
"The quick brown fox jumps over the lazy dog.",
105+
"In 2025, AI systems accelerated hardware-software co-design at scale.",
106+
"양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
107+
"今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
108+
"def quicksort(arr):\n if len(arr) <= 1: return arr\n ...",
109+
"Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
110+
]
111+
112+
with torch.no_grad():
113+
for prompt in PROMPTS:
114+
ids = tokenizer(prompt, return_tensors="pt")
115+
hidden = model.model.embed_tokens(ids["input_ids"])
116+
pos = rotary(hidden, ids["input_ids"]) # (cos, sin) tuple
117+
S = pos[0].shape[1]
118+
attn_mask = torch.zeros(1, 1, S, S) # causal-mask placeholder
119+
_ = qlayer(hidden, attention_mask=attn_mask, position_embeddings=pos)
120+
121+
convert(qlayer)
122+
123+
assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."
124+
125+
# -------------------------------------------------------------------------
126+
# 3. Quick INT-sim vs FP32 sanity check
127+
# -------------------------------------------------------------------------
128+
ids = tokenizer("check", return_tensors="pt")
129+
hidden = model.model.embed_tokens(ids["input_ids"])
130+
pos = rotary(hidden, ids["input_ids"])
131+
S = pos[0].shape[1]
132+
attn_mask = torch.zeros(1, 1, S, S)
133+
134+
with torch.no_grad():
135+
int8_out = qlayer(hidden, attention_mask=attn_mask, position_embeddings=pos)
136+
int8 = int8_out[0] if isinstance(int8_out, tuple) else int8_out
137+
fp32_out = fp32_layer(hidden, attention_mask=attn_mask, position_embeddings=pos)
138+
fp32 = fp32_out[0] if isinstance(fp32_out, tuple) else fp32_out
139+
140+
print("┌───────────── Quantization Error Summary ─────────────")
141+
print(f"│ Mean |diff|: {(int8 - fp32).abs().mean().item():.6f}")
142+
print(f"│ PEIR : {compute_peir(fp32, int8) * 100:.6f} %")
143+
print("└──────────────────────────────────────────────────────")
144+
print(plot_two_outputs(fp32, int8))
145+
146+
# -------------------------------------------------------------------------
147+
# 4. Export the calibrated layer to Circle
148+
# -------------------------------------------------------------------------
149+
import tico
150+
151+
save_path = pathlib.Path(
152+
"decoder_layer.q.circle"
153+
) # "decoder_layer_unsloth_LLama_3_2_1B_RMS_NORM_A16W4.q.circle"
154+
B, S, D = 1, 4, model.config.hidden_size
155+
example_hidden = torch.randn(B, S, D)
156+
example_pos = rotary(example_hidden, torch.arange(S)[None, :])
157+
attn_mask = torch.zeros(1, 1, S, S)
158+
159+
with SuppressWarning(UserWarning, ".*"):
160+
cm = tico.convert(
161+
qlayer,
162+
(example_hidden, attn_mask),
163+
{"position_embeddings": example_pos},
164+
strict=False,
165+
)
166+
# Note that the model is not fully quantized.
167+
cm.save(save_path)
168+
169+
print(f"Quantized Circle model saved to {save_path.resolve()}")

0 commit comments

Comments
 (0)