|
| 1 | +import torch |
| 2 | +import torch.functional as F |
| 3 | +from typing import Tuple |
| 4 | +import transformers |
| 5 | +from transformers import AutoModelForCausalLM, AutoTokenizer |
| 6 | +from datasets import load_dataset |
| 7 | +import re |
| 8 | + |
| 9 | +MODEL_ID = "facebook/opt-125m" |
| 10 | +# MODEL_ID = "echarlaix/tiny-random-mistral" |
| 11 | + |
| 12 | + |
| 13 | +NUM_PROMPTS = 512 |
| 14 | +MAX_SEQ_LEN = 512 |
| 15 | + |
| 16 | + |
| 17 | +# HACK: override the dtype_byte_size function in transformers to support float8 types |
| 18 | +def new_dtype_byte_size(dtype): |
| 19 | + if dtype == torch.bool: |
| 20 | + return 1 / 8 |
| 21 | + bit_search = re.search(r"[^\d](\d+)_?", str(dtype)) |
| 22 | + if bit_search is None: |
| 23 | + raise ValueError(f"`dtype` is not a valid dtype: {dtype}.") |
| 24 | + bit_size = int(bit_search.groups()[0]) |
| 25 | + return bit_size // 8 |
| 26 | + |
| 27 | + |
| 28 | +transformers.modeling_utils.dtype_byte_size = new_dtype_byte_size |
| 29 | + |
| 30 | + |
| 31 | +def per_tensor_quantize(tensor: torch.Tensor) -> Tuple[torch.Tensor, float]: |
| 32 | + """Quantize a tensor using per-tensor static scaling factor. |
| 33 | +
|
| 34 | + Args: |
| 35 | + tensor: The input tensor. |
| 36 | + """ |
| 37 | + finfo = torch.finfo(torch.float8_e4m3fn) |
| 38 | + # Calculate the scale as dtype max divided by absmax. |
| 39 | + # Since .abs() creates a new tensor, we use aminmax to get |
| 40 | + # the min and max first and then calculate the absmax. |
| 41 | + min_val, max_val = tensor.aminmax() |
| 42 | + amax = min_val.abs().max(max_val.abs()) |
| 43 | + scale = finfo.max / amax.clamp(min=1e-12) |
| 44 | + # scale and clamp the tensor to bring it to |
| 45 | + # the representative range of float8 data type |
| 46 | + # (as default cast is unsaturated) |
| 47 | + qweight = (tensor * scale).clamp(min=finfo.min, max=finfo.max) |
| 48 | + # Return both float8 data and the inverse scale (as float), |
| 49 | + # as both required as inputs to torch._scaled_mm |
| 50 | + qweight = qweight.to(torch.float8_e4m3fn) |
| 51 | + scale = scale.float().reciprocal() |
| 52 | + return qweight, scale |
| 53 | + |
| 54 | + |
| 55 | +def fp8_gemm(A, A_scale, B, B_scale, bias, out_dtype): |
| 56 | + cuda_compute_capability = torch.cuda.get_device_capability() |
| 57 | + if cuda_compute_capability >= (9, 0): |
| 58 | + output, _ = torch._scaled_mm( |
| 59 | + A, |
| 60 | + B.t(), |
| 61 | + out_dtype=out_dtype, |
| 62 | + scale_a=A_scale, |
| 63 | + scale_b=B_scale, |
| 64 | + bias=bias, |
| 65 | + ) |
| 66 | + else: |
| 67 | + output = torch.nn.functional.linear( |
| 68 | + A.to(out_dtype) * A_scale, |
| 69 | + B.to(out_dtype) * B_scale.to(out_dtype), |
| 70 | + bias=bias, |
| 71 | + ) |
| 72 | + return output |
| 73 | + |
| 74 | + |
| 75 | +class FP8StaticLinearQuantizer(torch.nn.Module): |
| 76 | + def __init__(self, qweight, weight_scale): |
| 77 | + super().__init__() |
| 78 | + self.weight = torch.nn.Parameter(qweight, requires_grad=False) |
| 79 | + self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) |
| 80 | + self.act_scale = None |
| 81 | + |
| 82 | + def forward(self, x): |
| 83 | + # Dynamically quantize |
| 84 | + qinput, x_act_scale = per_tensor_quantize(x) |
| 85 | + |
| 86 | + # Update scale if needed. |
| 87 | + if self.act_scale is None: |
| 88 | + self.act_scale = torch.nn.Parameter(x_act_scale) |
| 89 | + elif x_act_scale > self.act_scale: |
| 90 | + self.act_scale = torch.nn.Parameter(x_act_scale) |
| 91 | + |
| 92 | + # Pass quantized to next layer so it has realistic data. |
| 93 | + output = fp8_gemm( |
| 94 | + A=qinput, |
| 95 | + A_scale=self.act_scale, |
| 96 | + B=self.weight, |
| 97 | + B_scale=self.weight_scale, |
| 98 | + bias=None, |
| 99 | + out_dtype=x.dtype, |
| 100 | + ) |
| 101 | + return output |
| 102 | + |
| 103 | + |
| 104 | +class FP8StaticLinear(torch.nn.Module): |
| 105 | + def __init__(self, qweight, weight_scale, act_scale=0.0): |
| 106 | + super().__init__() |
| 107 | + self.weight = torch.nn.Parameter(qweight, requires_grad=False) |
| 108 | + self.weight_scale = torch.nn.Parameter(weight_scale, requires_grad=False) |
| 109 | + self.act_scale = torch.nn.Parameter(act_scale, requires_grad=False) |
| 110 | + |
| 111 | + def per_tensor_quantize( |
| 112 | + self, tensor: torch.Tensor, inv_scale: float |
| 113 | + ) -> torch.Tensor: |
| 114 | + # Scale and clamp the tensor to bring it to |
| 115 | + # the representative range of float8 data type |
| 116 | + # (as default cast is unsaturated) |
| 117 | + finfo = torch.finfo(torch.float8_e4m3fn) |
| 118 | + qweight = (tensor / inv_scale).clamp(min=finfo.min, max=finfo.max) |
| 119 | + return qweight.to(torch.float8_e4m3fn) |
| 120 | + |
| 121 | + def forward(self, x): |
| 122 | + qinput = self.per_tensor_quantize(x, inv_scale=self.act_scale) |
| 123 | + output = fp8_gemm( |
| 124 | + A=qinput, |
| 125 | + A_scale=self.act_scale, |
| 126 | + B=self.weight, |
| 127 | + B_scale=self.weight_scale, |
| 128 | + bias=None, |
| 129 | + out_dtype=x.dtype, |
| 130 | + ) |
| 131 | + return output |
| 132 | + |
| 133 | + |
| 134 | +class FP8DynamicLinear(torch.nn.Module): |
| 135 | + def __init__(self, qweight, scale): |
| 136 | + super().__init__() |
| 137 | + self.weight = torch.nn.Parameter(qweight, requires_grad=False) |
| 138 | + self.weight_scale = torch.nn.Parameter(scale, requires_grad=False) |
| 139 | + |
| 140 | + def forward(self, x): |
| 141 | + qinput, x_scale = per_tensor_quantize(x) |
| 142 | + output = fp8_gemm( |
| 143 | + A=qinput, |
| 144 | + A_scale=x_scale, |
| 145 | + B=self.weight, |
| 146 | + B_scale=self.weight_scale, |
| 147 | + bias=None, |
| 148 | + out_dtype=x.dtype, |
| 149 | + ) |
| 150 | + return output |
| 151 | + |
| 152 | + |
| 153 | +def quantize_weights(model): |
| 154 | + for name, linear in model.model.named_modules(): |
| 155 | + if not isinstance(linear, torch.nn.Linear): |
| 156 | + continue |
| 157 | + quant_weight, quant_scale = per_tensor_quantize(linear.weight) |
| 158 | + quant_linear = FP8DynamicLinear(quant_weight, quant_scale) |
| 159 | + if "." in name: |
| 160 | + parent_name = name.rsplit(".", 1)[0] |
| 161 | + child_name = name[len(parent_name) + 1 :] |
| 162 | + parent = model.model.get_submodule(parent_name) |
| 163 | + else: |
| 164 | + parent_name = "" |
| 165 | + parent = model.model |
| 166 | + child_name = name |
| 167 | + setattr(parent, child_name, quant_linear) |
| 168 | + |
| 169 | + |
| 170 | +def quantize_activations(model, calibration_tokens): |
| 171 | + # Replace layers with quantizer. |
| 172 | + for name, dynamic_quant_linear in model.model.named_modules(): |
| 173 | + if not isinstance(dynamic_quant_linear, FP8DynamicLinear): |
| 174 | + continue |
| 175 | + quantizer = FP8StaticLinearQuantizer( |
| 176 | + dynamic_quant_linear.weight, dynamic_quant_linear.weight_scale |
| 177 | + ) |
| 178 | + if "." in name: |
| 179 | + parent_name = name.rsplit(".", 1)[0] |
| 180 | + child_name = name[len(parent_name) + 1 :] |
| 181 | + parent = model.model.get_submodule(parent_name) |
| 182 | + else: |
| 183 | + parent_name = "" |
| 184 | + parent = model.model |
| 185 | + child_name = name |
| 186 | + setattr(parent, child_name, quantizer) |
| 187 | + |
| 188 | + # Calibration. |
| 189 | + for row_idx in range(calibration_tokens.shape[0]): |
| 190 | + _ = model(calibration_tokens[row_idx].reshape(1, -1)) |
| 191 | + |
| 192 | + # Replace quantizer with StaticLayer. |
| 193 | + for name, quantizer in model.model.named_modules(): |
| 194 | + if not isinstance(quantizer, FP8StaticLinearQuantizer): |
| 195 | + continue |
| 196 | + static_proj = FP8StaticLinear( |
| 197 | + quantizer.weight, quantizer.weight_scale, quantizer.act_scale |
| 198 | + ) |
| 199 | + if "." in name: |
| 200 | + parent_name = name.rsplit(".", 1)[0] |
| 201 | + child_name = name[len(parent_name) + 1 :] |
| 202 | + parent = model.model.get_submodule(parent_name) |
| 203 | + else: |
| 204 | + parent_name = "" |
| 205 | + parent = model.model |
| 206 | + child_name = name |
| 207 | + setattr(parent, child_name, static_proj) |
| 208 | + |
| 209 | + |
| 210 | +if __name__ == "__main__": |
| 211 | + tokenizer = AutoTokenizer.from_pretrained(MODEL_ID) |
| 212 | + sample_input_tokens = tokenizer.apply_chat_template( |
| 213 | + [{"role": "user", "content": "What is your name?"}], |
| 214 | + add_generation_prompt=True, |
| 215 | + return_tensors="pt", |
| 216 | + ).to("cuda") |
| 217 | + |
| 218 | + ds = load_dataset("HuggingFaceH4/ultrachat_200k", split="train_sft") |
| 219 | + ds = ds.shuffle(seed=42).select(range(NUM_PROMPTS)) |
| 220 | + ds = ds.map( |
| 221 | + lambda batch: { |
| 222 | + "text": tokenizer.apply_chat_template(batch["messages"], tokenize=False) |
| 223 | + } |
| 224 | + ) |
| 225 | + tokenizer.pad_token_id = tokenizer.eos_token_id |
| 226 | + calibration_tokens = tokenizer( |
| 227 | + ds["text"], |
| 228 | + return_tensors="pt", |
| 229 | + truncation=True, |
| 230 | + padding="max_length", |
| 231 | + max_length=MAX_SEQ_LEN, |
| 232 | + add_special_tokens=False, |
| 233 | + ).input_ids.to("cuda") |
| 234 | + print("Calibration tokens:", calibration_tokens.shape) |
| 235 | + |
| 236 | + # Load and test the model |
| 237 | + model = AutoModelForCausalLM.from_pretrained( |
| 238 | + MODEL_ID, torch_dtype="auto", device_map="auto" |
| 239 | + ) |
| 240 | + output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) |
| 241 | + print("ORIGINAL:\n", tokenizer.decode(output[0]), "\n\n") |
| 242 | + |
| 243 | + # Quantize weights. |
| 244 | + quantize_weights(model) |
| 245 | + output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) |
| 246 | + print("WEIGHT QUANT:\n", tokenizer.decode(output[0]), "\n\n") |
| 247 | + |
| 248 | + # Quantize activations. |
| 249 | + quantize_activations(model, calibration_tokens=calibration_tokens) |
| 250 | + output = model.generate(input_ids=sample_input_tokens, max_new_tokens=20) |
| 251 | + print("ACT QUANT:\n", tokenizer.decode(output[0]), "\n\n") |
| 252 | + |
| 253 | + # Save the model fully quantized |
| 254 | + output_path = "fp8-static-quant" |
| 255 | + print(f"Saving the model to {output_path}") |
| 256 | + static_q_dict = {"quantization_config": {"quant_method": "fp8", "scheme": "static"}} |
| 257 | + model.config.update(static_q_dict) |
| 258 | + model.save_pretrained(output_path) |
| 259 | + tokenizer.save_pretrained(output_path) |
0 commit comments