|
42 | 42 | from tico.quantization.wrapq.dtypes import DType |
43 | 43 | from tico.quantization.wrapq.observers.affine_base import AffineObserverBase |
44 | 44 | from tico.quantization.wrapq.observers.minmax import MinMaxObserver |
| 45 | +from tico.quantization.wrapq.observers.mx import MXObserver |
45 | 46 | from tico.quantization.wrapq.qscheme import QScheme |
46 | 47 | from tico.quantization.wrapq.utils.introspection import build_fqn_map |
47 | 48 | from tico.quantization.wrapq.utils.metrics import perplexity |
@@ -246,26 +247,26 @@ def main(): |
246 | 247 | print("Wrapping layers with PTQWrapper …") |
247 | 248 | w_cfg = { |
248 | 249 | "mlp": { |
249 | | - "gate_proj": {"weight": {"dtype": DType.uint(4)}}, |
250 | | - "up_proj": {"weight": {"dtype": DType.uint(4)}}, |
251 | | - "down_proj": {"weight": {"dtype": DType.uint(4)}}, |
| 250 | + "gate_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}}, |
| 251 | + "up_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}}, |
| 252 | + "down_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}}, |
252 | 253 | }, |
253 | 254 | "self_attn": { |
254 | | - "q_proj": {"weight": {"dtype": DType.uint(4)}}, |
255 | | - "k_proj": {"weight": {"dtype": DType.uint(4)}}, |
256 | | - "v_proj": {"weight": {"dtype": DType.uint(4)}}, |
257 | | - "o_proj": {"weight": {"dtype": DType.uint(4)}}, |
| 255 | + "q_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}}, |
| 256 | + "k_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}}, |
| 257 | + "v_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}}, |
| 258 | + "o_proj": {"weight": {"dtype": DType.uint(4), "observer":MinMaxObserver}}, |
258 | 259 | }, |
259 | 260 | } |
260 | 261 | cfg = PTQConfig( |
261 | | - default_dtype=DType.int(16), |
| 262 | + default_dtype=DType.int(8), |
262 | 263 | default_qscheme=QScheme.PER_TENSOR_SYMM, |
263 | | - default_observer=MinMaxObserver, |
| 264 | + default_observer=MXObserver,#MinMaxObserver, |
264 | 265 | overrides={ |
265 | 266 | "model.embeddings": { |
266 | | - "weight": {"dtype": DType.uint(8)} |
| 267 | + "weight": {"dtype": DType.uint(8), "observer":MinMaxObserver}, |
267 | 268 | }, # embeddings to 8-bits |
268 | | - "lm_head": {"weight": {"dtype": DType.uint(8)}}, # lm_head to 8-bits |
| 269 | + "lm_head": {"weight": {"dtype": DType.uint(8), "observer":MinMaxObserver}}, # lm_head to 8-bits |
269 | 270 | }, |
270 | 271 | ) |
271 | 272 | for i in range(len(q_m.model.layers)): |
|
0 commit comments