Skip to content

Commit f7bb4d9

Browse files
committed
[quantization] Full quantization
This draft tries to get a fully quantized model.

TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 41ba3b9 commit f7bb4d9

25 files changed

+1525
-36
lines changed

tico/passes/decompose_fake_quantize.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,27 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
123123
)
124124
node.replace_all_uses_with(dequnt, propagate_meta=True)
125125
modified = True
126+
127+
if node.target in [torch.ops.circle_custom.quantize_mx.default]:
128+
# tensor, elem_format, axis
129+
assert len(node.args) == 3
130+
_, elem_format, axis = node.args
131+
132+
with gm.graph.inserting_before(node):
133+
quant = create_node(
134+
g,
135+
torch.ops.circle_custom.quantize_mx_decomposed.default,
136+
args=node.args,
137+
origin=node,
138+
)
139+
dequnt = create_node(
140+
g,
141+
torch.ops.circle_custom.dequantize_mx_decomposed.default,
142+
args=(quant, *quant.args[1:]),
143+
kwargs=quant.kwargs,
144+
)
145+
node.replace_all_uses_with(dequnt, propagate_meta=True)
146+
modified = True
126147

127148
gm.graph.eliminate_dead_code()
128149
gm.graph.lint()

tico/quantization/passes/fold_quant_ops.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,50 @@
3030
QuantizePerTensorArgs,
3131
)
3232

33+
import copy
34+
from tico.utils.graph import create_node
35+
from tico.utils.utils import quant_min_max, set_new_meta_val
36+
from tico.quantization.passes.insert_quantize_on_dtype_mismatch import qparam_dtype
37+
38+
def _insert_mx_quantize_op(node, qparam):
    """Insert an MX quantize node directly after *node* and reroute its users.

    Args:
        node: fx Node whose output should be MX-quantized.
        qparam: QuantParam carrying ``dtype`` (MX element format, presumably
            e.g. "mxint8" — confirm against callers) and
            ``quantized_dimension`` (the shared-exponent axis).

    Returns:
        The newly created ``quantize_mx_decomposed`` node, with a deep copy
        of *qparam* attached under ``QPARAM_KEY``.
    """
    graph = node.graph
    # Both fields are mandatory arguments of the MX quantize op below.
    assert qparam.quantized_dimension is not None
    assert qparam.dtype is not None

    with graph.inserting_after(node):
        # Argument order: (tensor, elem_format, axis).
        q_args = (node, qparam.dtype, qparam.quantized_dimension)
        quantize = create_node(
            graph,
            torch.ops.circle_custom.quantize_mx_decomposed.default,
            args=q_args,
        )

    node.replace_all_uses_with(quantize, propagate_meta=True)
    # replace_all_uses_with also rewired `quantize`'s own input to itself;
    # restore the original producer as its input.
    quantize.replace_input_with(quantize, node)

    quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)

    return quantize
57+
58+
def _insert_quantize_op(node, qparam):
    """Insert a per-tensor quantize node directly after *node* and reroute its users.

    Args:
        node: fx Node whose output should be quantized.
        qparam: QuantParam with per-tensor ``scale`` / ``zero_point`` lists
            (only element [0] is used) and a ``dtype`` name resolvable via
            ``getattr(torch, ...)`` (e.g. "uint8", "int16").

    Returns:
        The newly created ``quantize_per_tensor`` node, with a deep copy of
        *qparam* attached under ``QPARAM_KEY``.
    """
    graph = node.graph
    # Quantization range is derived from the dtype name, not stored in qparam.
    min_, max_ = quant_min_max(qparam.dtype)
    dtype = getattr(torch, qparam.dtype)

    with graph.inserting_after(node):
        # Argument order: (tensor, scale, zero_point, quant_min, quant_max, dtype).
        q_args = (node, qparam.scale[0], qparam.zero_point[0], min_, max_, dtype)
        quantize = create_node(
            graph,
            torch.ops.quantized_decomposed.quantize_per_tensor.default,
            args=q_args,
        )

    node.replace_all_uses_with(quantize, propagate_meta=True)
    # replace_all_uses_with also rewired `quantize`'s own input to itself;
    # restore the original producer as its input.
    quantize.replace_input_with(quantize, node)

    quantize.meta[QPARAM_KEY] = copy.deepcopy(qparam)

    return quantize
3377

3478
@trace_graph_diff_on_pass
3579
class FoldQuantOps(PassBase):
@@ -114,6 +158,13 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
114158
dq.replace_all_uses_with(op, propagate_meta=False)
115159

116160
logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
161+
assert(QPARAM_KEY not in dq.meta) # we should not abandon quantization calibrated parameters
162+
#if QPARAM_KEY in dq.meta: #right now it's not needed
163+
# if (qparam_dtype(op) == "int16" or qparam_dtype(op) == "uint8") and qparam_dtype(dq) == "mxint8":
164+
# #need to insert requantization
165+
# assert(False)
166+
# _insert_mx_quantize_op(op, dq.meta[QPARAM_KEY])
167+
117168
# ───────────────────────────────────────────
118169
# Case 2: op already quantized
119170
# 2.1 same dtype → nothing to do
@@ -144,7 +195,83 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
144195
)
145196
dq.replace_all_uses_with(op, propagate_meta=False)
146197
logger.debug(f"Removed redundant {dq.name}")
198+
199+
for dq in graph.nodes:
200+
if dq.op != "call_function":
201+
continue
202+
if (
203+
dq.target
204+
!= torch.ops.circle_custom.dequantize_mx_decomposed.default
205+
):
206+
continue
207+
208+
dq_args = dq.args
209+
210+
q = dq_args[0]
211+
if q.target != torch.ops.circle_custom.quantize_mx_decomposed.default:
212+
continue
213+
q_args = q.args
214+
op = q_args[0]
215+
216+
# Check if Q and DQ have same parameters
217+
if q_args[1] != dq_args[1]:
218+
continue
219+
if q_args[2] != dq_args[2]:
220+
continue
221+
222+
# ───────────────────────────────────────────
223+
# Case 1: op not yet quantized
224+
# ───────────────────────────────────────────
225+
if QPARAM_KEY not in op.meta:
226+
#TODO
227+
qparam = QuantParam()
228+
qparam.dtype = "mxint8"# q_args[1] #TODO
229+
qparam.quantized_dimension = q_args[2]
230+
op.meta[QPARAM_KEY] = qparam
231+
232+
dq.replace_all_uses_with(op, propagate_meta=False)
233+
234+
logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
235+
if QPARAM_KEY in dq.meta:
236+
if qparam_dtype(op) == "mxint8" and (qparam_dtype(dq) == "int16" or qparam_dtype(dq) == "uint8"):
237+
#need to insert requantization
238+
_insert_quantize_op(op, dq.meta[QPARAM_KEY])
147239

240+
# ───────────────────────────────────────────
241+
# Case 2: op already quantized
242+
# 2.1 same dtype → nothing to do
243+
# 2.2 diff dtype → leave Q in place
244+
# ───────────────────────────────────────────
245+
else:
246+
op_qparam: QuantParam = op.meta[QPARAM_KEY]
247+
qdq_dtype = "mxint8"#q_args[1] #TODO
248+
249+
if op_qparam.dtype != qdq_dtype:
250+
# Attach QPARAM to Q once
251+
if QPARAM_KEY not in q.meta:
252+
qparam = QuantParam()
253+
qparam.dtype = qdq_dtype
254+
qparam.quantized_dimension = q_args[2]
255+
q.meta[QPARAM_KEY] = qparam
256+
assert len(q.users) == 1, "Fix me unless"
257+
258+
dq.replace_all_uses_with(q, propagate_meta=False)
259+
logger.debug(f"{dq.name} is folded ({q.name} is left).")
260+
else:
261+
# Same dtype → the Quantize–Dequantize pair is redundant.
262+
assert not op_qparam.scale
263+
assert not op_qparam.zero_point
264+
assert (
265+
op_qparam.dtype
266+
and op_qparam.dtype == 'mxint8' #TODO
267+
)
268+
assert (
269+
op_qparam.quantized_dimension is not None
270+
and op_qparam.quantized_dimension == q_args[2]
271+
)
272+
dq.replace_all_uses_with(op, propagate_meta=False)
273+
logger.debug(f"Removed redundant {dq.name}")
274+
148275
graph.eliminate_dead_code()
149276
graph.lint()
150277
graph_module.recompile()

0 commit comments

Comments
 (0)