|
| 1 | +# Copyright (c) 2025 Samsung Electronics Co., Ltd. All Rights Reserved |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# ============================================================================= |
| 16 | +# POST-TRAINING QUANTIZATION EXAMPLE — Llama Decoder Layer (Self-Attn + MLP) |
| 17 | +# ----------------------------------------------------------------------------- |
| 18 | +# This demo shows how to: |
| 19 | +# 1. Replace a single FP32 `LlamaDecoderLayer` with `QuantLlamaDecoderLayer`. |
| 20 | +# 2. Collect activation statistics in one calibration sweep. |
| 21 | +# 3. Freeze scales / zero-points and switch to INT-simulation mode. |
| 22 | +#   4. Compare quantized (A16/W4 simulated) vs FP32 outputs with a quick mean-absolute-diff check.
| 23 | +# 5. Export the calibrated, quantized block to a Circle model. |
| 24 | +# ----------------------------------------------------------------------------- |
| 25 | +# Style / layout is kept identical to the `quantize_llama_attn.py` and |
| 26 | +# `quantize_llama_mlp.py` examples for easy side-by-side reading. |
| 27 | +# ============================================================================= |
| 28 | + |
import os
import pathlib

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from tico.quantization import convert, prepare
from tico.quantization.config.ptq import PTQConfig
from tico.quantization.evaluation.metric import compute_peir
from tico.quantization.evaluation.utils import plot_two_outputs
from tico.quantization.wrapq.dtypes import DType
from tico.quantization.wrapq.mode import Mode
from tico.quantization.wrapq.observers.minmax import MinMaxObserver
from tico.quantization.wrapq.observers.mx import MXObserver
from tico.quantization.wrapq.qscheme import QScheme
from tico.quantization.wrapq.wrappers.llama.quant_decoder_layer import (
    QuantLlamaDecoderLayer,
)
from tico.utils.utils import SuppressWarning
| 47 | + |
# -------------------------------------------------------------------------
# 0. Load the FP32 reference model and tokenizer
# -------------------------------------------------------------------------
# Smaller alternatives for a quick smoke test:
#   "Maykeye/TinyLLama-v0", "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
MODEL_NAME = "unsloth/Llama-3.2-3B-Instruct"

# Allow the HF cache directory to be overridden via the environment so the
# example is not tied to one machine's filesystem layout; the original
# hard-coded path remains the fallback for backward compatibility.
CACHE_DIR = os.environ.get("TRANSFORMERS_CACHE_DIR", "/mnt/storage/transformers_cache")

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=CACHE_DIR)

model.config.max_position_embeddings = 2048  # we need this to prevent RAM exhaust
model.config.use_cache = False  # no KV-cache needed for single-layer calibration

model.eval()  # disable dropout, etc.
rotary = model.model.rotary_emb  # RoPE helper
| 60 | + |
# -------------------------------------------------------------------------
# 1. Swap in the quant wrapper
# -------------------------------------------------------------------------
fp32_layer = model.model.layers[0]  # keep a reference for diff check


def _weight_override(dtype):
    """Return a fresh per-module override: quantize the weight to *dtype*
    with a MinMax observer.  A new dict is built on every call so the
    config never shares (and cannot accidentally co-mutate) nested dicts."""
    return {"weight": {"dtype": dtype, "observer": MinMaxObserver}}


# Every linear projection uses the same W4 + MinMax recipe; build the eight
# identical overrides programmatically instead of repeating the literal.
cfg = PTQConfig(
    default_dtype=DType.int(16),
    default_qscheme=QScheme.PER_TENSOR_SYMM,
    # MXObserver is the default for everything not overridden below.
    default_observer=MXObserver,
    overrides={
        "mlp": {
            name: _weight_override(DType.uint(4))
            for name in ("gate_proj", "up_proj", "down_proj")
        },
        "self_attn": {
            **{
                name: _weight_override(DType.uint(4))
                for name in ("q_proj", "k_proj", "v_proj", "o_proj")
            },
            "scale": {"observer": MinMaxObserver},
        },
        "input_layernorm": _weight_override(DType.int(16)),
        "post_attention_layernorm": _weight_override(DType.int(16)),
    },
)

model.model.layers[0] = prepare(fp32_layer, cfg)
model.eval()

qlayer = model.model.layers[0]  # alias for brevity
assert isinstance(qlayer.wrapped, QuantLlamaDecoderLayer)
| 99 | + |
# -------------------------------------------------------------------------
# 2. Single-pass calibration (gather activation ranges)
# -------------------------------------------------------------------------
# Mixed scripts / lengths / punctuation so the observers see a spread of
# activation distributions in one sweep.
PROMPTS = [
    "The quick brown fox jumps over the lazy dog.",
    "In 2025, AI systems accelerated hardware-software co-design at scale.",
    "양자화는 왜 어려울까? 분포, 길이, 마스크가 관건이다.",
    "今日はいい天気ですね。ところでRoPE角度は長さに依存します。",
    "def quicksort(arr):\n if len(arr) <= 1: return arr\n ...",
    "Prices rose 3.14% — see Figure 2; emails: foo@bar.com!",
]

with torch.no_grad():
    for text in PROMPTS:
        token_ids = tokenizer(text, return_tensors="pt")["input_ids"]
        embeds = model.model.embed_tokens(token_ids)
        cos_sin = rotary(embeds, token_ids)  # (cos, sin) tuple
        seq_len = cos_sin[0].shape[1]
        # All-zero mask stands in for a causal mask during calibration.
        mask = torch.zeros(1, 1, seq_len, seq_len)
        qlayer(embeds, attention_mask=mask, position_embeddings=cos_sin)

# Freeze scales / zero-points and flip the wrapper into INT simulation.
convert(qlayer)

assert qlayer._mode is Mode.QUANT, "Quantization mode should be active now."
| 124 | + |
# -------------------------------------------------------------------------
# 3. Quick INT-sim vs FP32 sanity check
# -------------------------------------------------------------------------
probe = tokenizer("check", return_tensors="pt")
probe_hidden = model.model.embed_tokens(probe["input_ids"])
probe_pos = rotary(probe_hidden, probe["input_ids"])
probe_len = probe_pos[0].shape[1]
probe_mask = torch.zeros(1, 1, probe_len, probe_len)


def _hidden_states(layer_out):
    """Unwrap a HF-style layer output (tuple or bare tensor) to its tensor."""
    return layer_out[0] if isinstance(layer_out, tuple) else layer_out


with torch.no_grad():
    quant_out = _hidden_states(
        qlayer(probe_hidden, attention_mask=probe_mask, position_embeddings=probe_pos)
    )
    ref_out = _hidden_states(
        fp32_layer(probe_hidden, attention_mask=probe_mask, position_embeddings=probe_pos)
    )

print("┌───────────── Quantization Error Summary ─────────────")
print(f"│ Mean |diff|: {(quant_out - ref_out).abs().mean().item():.6f}")
print(f"│ PEIR       : {compute_peir(ref_out, quant_out) * 100:.6f} %")
print("└──────────────────────────────────────────────────────")
print(plot_two_outputs(ref_out, quant_out))
| 145 | + |
# -------------------------------------------------------------------------
# 4. Export the calibrated layer to Circle
# -------------------------------------------------------------------------
import tico

save_path = pathlib.Path(
    "decoder_layer.q.circle"
)  # "decoder_layer_unsloth_LLama_3_2_1B_RMS_NORM_A16W4.q.circle"

# Tiny example inputs are enough for tracing: batch 1, sequence of 4 tokens.
batch, seq, dim = 1, 4, model.config.hidden_size
sample_hidden = torch.randn(batch, seq, dim)
sample_pos = rotary(sample_hidden, torch.arange(seq)[None, :])
sample_mask = torch.zeros(1, 1, seq, seq)

with SuppressWarning(UserWarning, ".*"):
    circle_model = tico.convert(
        qlayer,
        (sample_hidden, sample_mask),
        {"position_embeddings": sample_pos},
        strict=False,
    )
# Note that the model is not fully quantized.
circle_model.save(save_path)

print(f"Quantized Circle model saved to {save_path.resolve()}")