Skip to content

Commit 0253cb9

Browse files
committed
Serialize mx.
TICO-DCO-1.0-Signed-off-by: s.malakhov <s.malakhov@partner.samsung.com>
1 parent 221cb5c commit 0253cb9

File tree

9 files changed

+225
-22
lines changed

9 files changed

+225
-22
lines changed

tico/passes/decompose_fake_quantize.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,27 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
123123
)
124124
node.replace_all_uses_with(dequnt, propagate_meta=True)
125125
modified = True
126+
127+
if node.target in [torch.ops.circle_custom.quantize_mx.default]:
128+
# tensor, scale, zero_p, quant_min, quant_max
129+
assert len(node.args) == 3
130+
_, elem_format, axis = node.args
131+
132+
with gm.graph.inserting_before(node):
133+
quant = create_node(
134+
g,
135+
torch.ops.circle_custom.quantize_float_to_mx.default,
136+
args=node.args,
137+
origin=node,
138+
)
139+
dequnt = create_node(
140+
g,
141+
torch.ops.circle_custom.dequantize_mx_to_float.default,
142+
args=(quant, *quant.args[1:]),
143+
kwargs=quant.kwargs,
144+
)
145+
node.replace_all_uses_with(dequnt, propagate_meta=True)
146+
modified = True
126147

127148
gm.graph.eliminate_dead_code()
128149
gm.graph.lint()

tico/quantization/passes/fold_quant_ops.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,78 @@ def call(self, exported_program: ExportedProgram) -> PassResult:
144144
)
145145
dq.replace_all_uses_with(op, propagate_meta=False)
146146
logger.debug(f"Removed redundant {dq.name}")
147+
148+
for dq in graph.nodes:
149+
if dq.op != "call_function":
150+
continue
151+
if (
152+
dq.target
153+
!= torch.ops.circle_custom.dequantize_mx_to_float.default
154+
):
155+
continue
156+
157+
dq_args = dq.args
158+
159+
q = dq_args[0]
160+
if q.target != torch.ops.circle_custom.quantize_float_to_mx.default:
161+
continue
162+
q_args = q.args
163+
op = q_args[0]
164+
165+
# Check if Q and DQ have same parameters
166+
if q_args[1] != dq_args[1]:
167+
continue
168+
if q_args[2] != dq_args[2]:
169+
continue
170+
171+
# ───────────────────────────────────────────
172+
# Case 1: op not yet quantized
173+
# ───────────────────────────────────────────
174+
if QPARAM_KEY not in op.meta:
175+
#TODO
176+
qparam = QuantParam()
177+
qparam.dtype = "mxint8"# q_args[1] #TODO
178+
qparam.quantized_dimension = q_args[2]
179+
op.meta[QPARAM_KEY] = qparam
180+
181+
dq.replace_all_uses_with(op, propagate_meta=False)
182+
183+
logger.debug(f"{q.name} and {dq.name} are folded to {op.name}.")
184+
# ───────────────────────────────────────────
185+
# Case 2: op already quantized
186+
# 2.1 same dtype → nothing to do
187+
# 2.2 diff dtype → leave Q in place
188+
# ───────────────────────────────────────────
189+
else:
190+
op_qparam: QuantParam = op.meta[QPARAM_KEY]
191+
qdq_dtype = "mxint8"#q_args[1] #TODO
192+
193+
if op_qparam.dtype != qdq_dtype:
194+
# Attach QPARAM to Q once
195+
if QPARAM_KEY not in q.meta:
196+
qparam = QuantParam()
197+
qparam.dtype = qdq_dtype
198+
qparam.quantized_dimension = q_args[2]
199+
q.meta[QPARAM_KEY] = qparam
200+
assert len(q.users) == 1, "Fix me unless"
147201

202+
dq.replace_all_uses_with(q, propagate_meta=False)
203+
logger.debug(f"{dq.name} is folded ({q.name} is left).")
204+
else:
205+
# Same dtype → the Quantize–Dequantize pair is redundant.
206+
assert not op_qparam.scale
207+
assert not op_qparam.zero_point
208+
assert (
209+
op_qparam.dtype
210+
and op_qparam.dtype == 'mxint8' #TODO
211+
)
212+
assert (
213+
op_qparam.quantized_dimension is not None
214+
and op_qparam.quantized_dimension == q_args[2]
215+
)
216+
dq.replace_all_uses_with(op, propagate_meta=False)
217+
logger.debug(f"Removed redundant {dq.name}")
218+
148219
graph.eliminate_dead_code()
149220
graph.lint()
150221
graph_module.recompile()

tico/quantization/wrapq/examples/quantize_linear.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@
3636
from tico.quantization.wrapq.mode import Mode
3737
from tico.quantization.wrapq.wrappers.nn.quant_linear import QuantLinear
3838
from tico.utils.utils import SuppressWarning
39+
from tico.quantization.wrapq.dtypes import DType
40+
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
41+
from tico.quantization.wrapq.observers.mx import MXObserver
42+
from tico.quantization.wrapq.qscheme import QScheme
3943

4044

4145
# -------------------------------------------------------------------------
@@ -62,7 +66,14 @@ def forward(self, x):
6266
# -------------------------------------------------------------------------
6367
# 1. Replace the Linear with QuantLinear wrapper
6468
# -------------------------------------------------------------------------
65-
model.fc = prepare(fp32_layer, PTQConfig()) # type: ignore[assignment]
69+
cfg = PTQConfig(
70+
default_dtype=DType.uint(8),
71+
default_qscheme=QScheme.PER_TENSOR_ASYMM,
72+
default_observer=MXObserver,#MinMaxObserver,
73+
overrides = {"weight": {"dtype": DType.uint(4), "observer": MinMaxObserver}},
74+
)
75+
76+
model.fc = prepare(fp32_layer,cfg) # type: ignore[assignment]
6677
qlayer = model.fc # alias for brevity
6778

6879
# -------------------------------------------------------------------------

tico/quantization/wrapq/examples/quantize_llama_mlp.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
from tico.quantization.wrapq.qscheme import QScheme
2828
from tico.quantization.wrapq.wrappers.llama.quant_mlp import QuantLlamaMLP
2929
from tico.utils.utils import SuppressWarning
30+
from tico.quantization.wrapq.dtypes import DType
31+
from tico.quantization.wrapq.mode import Mode
32+
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
33+
from tico.quantization.wrapq.observers.mx import MXObserver
34+
from tico.quantization.wrapq.qscheme import QScheme
3035

3136
name = "Maykeye/TinyLLama-v0"
3237
model = AutoModelForCausalLM.from_pretrained(name)
@@ -37,8 +42,19 @@
3742
# 1. Replace layer-0’s MLP with QuantLlamaMLP
3843
# -------------------------------------------------------------------------
3944
fp32_mlp = model.model.layers[0].mlp
45+
cfg = PTQConfig(
46+
default_dtype=DType.int(16),
47+
default_qscheme=QScheme.PER_TENSOR_SYMM,
48+
default_observer=MXObserver, #MinMaxObserver,
49+
overrides={
50+
"gate_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
51+
"up_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
52+
"down_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
53+
}
54+
)
55+
4056
model.model.layers[0].mlp = prepare(
41-
fp32_mlp, PTQConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM)
57+
fp32_mlp, cfg#PTQConfig(default_dtype=INT16, default_qscheme=QScheme.PER_TENSOR_SYMM)
4258
)
4359
model.eval()
4460

tico/quantization/wrapq/examples/quantize_llama_whole_decoder_layer.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,14 @@
3838
from tico.quantization.wrapq.dtypes import DType
3939
from tico.quantization.wrapq.mode import Mode
4040
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
41+
from tico.quantization.wrapq.observers.mx import MXObserver
4142
from tico.quantization.wrapq.qscheme import QScheme
4243
from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
4344
QuantLlamaDecoderLayer,
4445
)
4546
from tico.utils.utils import SuppressWarning
4647

47-
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct" # "Maykeye/TinyLLama-v0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
48+
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct" #"Maykeye/TinyLLama-v0" # "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
4849
model = AutoModelForCausalLM.from_pretrained(
4950
MODEL_NAME, cache_dir="/mnt/storage/transformers_cache"
5051
)
@@ -64,20 +65,23 @@
6465
cfg = PTQConfig(
6566
default_dtype=DType.int(16),
6667
default_qscheme=QScheme.PER_TENSOR_SYMM,
67-
default_observer=MinMaxObserver,
68+
default_observer=MXObserver, #MinMaxObserver
6869
overrides={
6970
# local override: input observer now MinMax & 4-bit, per-channel asymmetric
7071
"mlp": {
71-
"gate_proj": {"weight": {"dtype": DType.uint(4)}},
72-
"up_proj": {"weight": {"dtype": DType.uint(4)}},
73-
"down_proj": {"weight": {"dtype": DType.uint(4)}},
72+
"gate_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
73+
"up_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
74+
"down_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
7475
},
7576
"self_attn": {
76-
"q_proj": {"weight": {"dtype": DType.uint(4)}},
77-
"k_proj": {"weight": {"dtype": DType.uint(4)}},
78-
"v_proj": {"weight": {"dtype": DType.uint(4)}},
79-
"o_proj": {"weight": {"dtype": DType.uint(4)}},
77+
"q_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
78+
"k_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
79+
"v_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
80+
"o_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
81+
"scale": {"observer":MinMaxObserver},
8082
},
83+
"input_layernorm" : {},
84+
"post_attention_layernorm" : {},
8185
},
8286
)
8387

tico/quantization/wrapq/examples/quantize_with_gptq.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
from tico.quantization.wrapq.dtypes import DType
4343
from tico.quantization.wrapq.observers.affine_base import AffineObserverBase
4444
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
45+
from tico.quantization.wrapq.observers.mx import MXObserver
4546
from tico.quantization.wrapq.qscheme import QScheme
4647
from tico.quantization.wrapq.utils.introspection import build_fqn_map
4748
from tico.quantization.wrapq.utils.metrics import perplexity
@@ -246,26 +247,26 @@ def main():
246247
print("Wrapping layers with PTQWrapper …")
247248
w_cfg = {
248249
"mlp": {
249-
"gate_proj": {"weight": {"dtype": DType.uint(4)}},
250-
"up_proj": {"weight": {"dtype": DType.uint(4)}},
251-
"down_proj": {"weight": {"dtype": DType.uint(4)}},
250+
"gate_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
251+
"up_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
252+
"down_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
252253
},
253254
"self_attn": {
254-
"q_proj": {"weight": {"dtype": DType.uint(4)}},
255-
"k_proj": {"weight": {"dtype": DType.uint(4)}},
256-
"v_proj": {"weight": {"dtype": DType.uint(4)}},
257-
"o_proj": {"weight": {"dtype": DType.uint(4)}},
255+
"q_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
256+
"k_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
257+
"v_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
258+
"o_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}},
258259
},
259260
}
260261
cfg = PTQConfig(
261-
default_dtype=DType.int(16),
262+
default_dtype=DType.int(8),
262263
default_qscheme=QScheme.PER_TENSOR_SYMM,
263-
default_observer=MinMaxObserver,
264+
default_observer=MXObserver,#MinMaxObserver,
264265
overrides={
265266
"model.embeddings": {
266-
"weight": {"dtype": DType.uint(8)}
267+
"weight": {"dtype": DType.uint(8), "observer":MinMaxObserver},
267268
}, # embeddings to 8-bits
268-
"lm_head": {"weight": {"dtype": DType.uint(8)}}, # lm_head to 8-bits
269+
"lm_head": {"weight": {"dtype": DType.uint(8), "observer":MinMaxObserver}}, # lm_head to 8-bits
269270
},
270271
)
271272
for i in range(len(q_m.model.layers)):

tico/serialize/circle_mapping.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ def str_to_circle_dtype(
6363
"int64": circle.TensorType.TensorType.INT64,
6464
"bool": circle.TensorType.TensorType.BOOL,
6565
"uint4": circle.TensorType.TensorType.UINT4,
66+
"mxint8": circle.TensorType.TensorType.MXINT8,
67+
"mxfp4": circle.TensorType.TensorType.MXFP4,
6668
# TODO Add more dtypes
6769
}
6870

tico/utils/register_custom_op.py

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -703,7 +703,80 @@ def _(
703703
round: str = "nearest", # Fixed
704704
) -> torch.Tensor:
705705
return input_
706+
707+
def CircleQuantizeFloatToMX():
708+
#TODO
709+
@custom_op("circle_custom::quantize_float_to_mx", mutates_args=())
710+
def quantize_mx(
711+
input_: torch.Tensor,
712+
elem_format: str,
713+
axis: int,
714+
shared_exp_method: str = "max",
715+
round: str = "nearest",
716+
) -> torch.Tensor:
717+
if elem_format == "int8":
718+
scale_bits = 8
719+
block_size = 32
720+
else:
721+
raise RuntimeError(f"Unsupported elem_format in quantize_mx: {elem_format}")
722+
723+
result = _quantize_mx(
724+
input_,
725+
scale_bits=scale_bits,
726+
elem_format=elem_format,
727+
axes=[axis],
728+
block_size=block_size,
729+
shared_exp_method=shared_exp_method,
730+
round=round,
731+
)
732+
return result
733+
734+
@register_fake("circle_custom::quantize_float_to_mx")
735+
def _(
736+
input_: torch.Tensor,
737+
elem_format: str,
738+
axis: int,
739+
shared_exp_method: str = "max", # Fixed
740+
round: str = "nearest", # Fixed
741+
) -> torch.Tensor:
742+
return input_
706743

744+
def CircleDeQuantizeMXToFloat():
745+
#TODO
746+
@custom_op("circle_custom::dequantize_mx_to_float", mutates_args=())
747+
def quantize_mx(
748+
input_: torch.Tensor,
749+
elem_format: str,
750+
axis: int,
751+
shared_exp_method: str = "max",
752+
round: str = "nearest",
753+
) -> torch.Tensor:
754+
if elem_format == "int8":
755+
scale_bits = 8
756+
block_size = 32
757+
else:
758+
raise RuntimeError(f"Unsupported elem_format in quantize_mx: {elem_format}")
759+
760+
result = _quantize_mx(
761+
input_,
762+
scale_bits=scale_bits,
763+
elem_format=elem_format,
764+
axes=[axis],
765+
block_size=block_size,
766+
shared_exp_method=shared_exp_method,
767+
round=round,
768+
)
769+
return result
770+
771+
@register_fake("circle_custom::dequantize_mx_to_float")
772+
def _(
773+
input_: torch.Tensor,
774+
elem_format: str,
775+
axis: int,
776+
shared_exp_method: str = "max", # Fixed
777+
round: str = "nearest", # Fixed
778+
) -> torch.Tensor:
779+
return input_
707780

708781
def CircleRMSNorm():
709782
@custom_op("circle_custom::rms_norm", mutates_args=())
@@ -800,6 +873,8 @@ def RegisterOps():
800873
CircleAvgPool2D()
801874
CircleInstanceNorm()
802875
CircleQuantizeMX()
876+
CircleQuantizeFloatToMX()
877+
CircleDeQuantizeMXToFloat()
803878
CircleRMSNorm()
804879
CircleAttention()
805880
CircleShape()

tico/utils/utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,8 @@ def has_quantization_ops(graph: torch.fx.Graph):
268268
torch.ops.quantized_decomposed.quantize_per_channel.default,
269269
torch.ops.quantized_decomposed.dequantize_per_tensor.default,
270270
torch.ops.quantized_decomposed.dequantize_per_channel.default,
271+
torch.ops.circle_custom.quantize_float_to_mx.default,
272+
torch.ops.circle_custom.dequantize_mx_to_float.default,
271273
]
272274
for node in graph.nodes:
273275
if node.op != "call_function":

0 commit comments

Comments
 (0)