# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Standard
from pathlib import Path
import logging

# Third Party
from transformers.modeling_utils import PreTrainedModel
import torch

# logging is only enabled for verbose output (performance is less critical during debug),
# and f-string style logging is preferred for code readability
# pylint: disable=logging-not-lazy


logger = logging.getLogger()


def get_quantized_linear_names(model_type: str) -> list[str]:
    """Return identifiers of the quantized linear layers for a given model type."""

    if model_type in ["granite", "llama"]:
        return [
            "self_attn.q_proj",
            "self_attn.k_proj",
            "self_attn.v_proj",
            "self_attn.o_proj",
            "mlp.gate_proj",
            "mlp.up_proj",
            "mlp.down_proj",
        ]
    if model_type == "gpt_bigcode":
        return [
            "attn.c_attn",
            "attn.c_proj",
            "mlp.c_fc",
            "mlp.c_proj",
        ]
    if model_type in ["bert", "roberta"]:
        return [
            "attention.self.query",
            "attention.self.key",
            "attention.self.value",
            "attention.output.dense",
            "intermediate.dense",
            "output.dense",
        ]
    raise NotImplementedError(
        f"Model type {model_type} is not supported for quantized checkpoint saving"
    )

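# NOTE (illustrative): the suffixes above are matched by substring against state dict
# keys in convert_sd_for_aiu, e.g. "self_attn.q_proj" matches a key such as
# "model.layers.0.self_attn.q_proj.weight" in a llama-style checkpoint.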

def convert_sd_for_aiu(
    model: PreTrainedModel,
    verbose: bool,
) -> dict[str, torch.Tensor]:
    """Convert the state dictionary (sd) of an FMS-MO-quantized model into a format
    compatible with the AIU.

    Expected tensors in the input state dictionary:
        weights: [out_feat, in_feat]
        w_cv:    per-tensor [1] or per-channel [out_feat]
        w_cvn:   always equal to -w_cv (weight quantization is symmetric)
        a_cv:    per-tensor [1] (n/a for per-token-max activation quantization)
        a_cvn:   symmetric or asymmetric

    The combined smoothquant scale is computed as:
        s_sq = a_sq_scale^alpha / w_sq_scale^(1 - alpha)

    All parameters except the quantized weights are cast to FP16, as required by the AIU.
    """

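    # Conversion sketch (for orientation only; the code below is authoritative):
    # with symmetric 8-bit weight quantization, a weight tensor W with clip value
    # w_cv is exported as
    #     W_int8 = round(clamp(127 * W / w_cv, -127, 127))
    # and is expected to be dequantized downstream as W ~ W_int8 * w_cv / 127.
    # If smoothquant was used, W is first rescaled by the combined scale
    # s_sq = a_sq_scale**alpha / w_sq_scale**(1 - alpha), which is also stored as
    # "<layer>.smoothq_scale" (presumably so the runtime can apply the inverse
    # scaling on the activation side).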
    if verbose:
        logger.info("Before conversion:")
        logger.info("* ALL MODEL PARAMETERS (name, size, dtype)")
        logger.info(
            "\n"
            + "\n".join(
                f"{k:80} {str(list(v.size())):15} {v.dtype}"
                for k, v in model.named_parameters()
            )
        )
        logger.info("* ALL BUFFERS (name, size, dtype)")
        logger.info(
            "\n"
            + "\n".join(
                f"{k:80} {str(list(v.size())):15} {v.dtype}"
                for k, v in model.named_buffers()
            )
        )
        logger.info("=" * 60)

    model_type = getattr(model.config, "model_type", None)
    if model_type:
        quantized_layers = get_quantized_linear_names(model_type)
    else:
        raise ValueError(
            "Could not determine model type to save quantized state dictionary."
        )
    excluded_keys_from_new_sd = [
        "calib_counter",
        "num_module_called",
        "smoothq_act_scale",
        "smoothq_alpha",
        "obsrv_w_clipval",
        "obsrv_clipval",
        "obsrv_clipvaln",
    ]
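    # (assumed rationale) the keys above are calibration/observer bookkeeping, or
    # smoothquant inputs that get folded into the exported "smoothq_scale", so they
    # are dropped from the converted state dict.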

    new_sd = {}
    for k, v in model.state_dict().items():
        if k.endswith(".weight") and any(qlayer in k for qlayer in quantized_layers):
            layername = k[:-7]

            # smoothquant processing:
            # - if smoothquant wasn't used, smoothq_alpha doesn't exist or
            #   smoothq_act_scale is all zeros
            # - compute the combined weight/activation smoothquant scaling factor (sq_scale)
            # - rescale weights before quantization
            # - store the scaling factor in the state dict
            v_scaled = None
            if layername + ".smoothq_alpha" in model.state_dict():
                sq_a_scale = model.state_dict()[layername + ".smoothq_act_scale"]
                if sq_a_scale.sum() != 0:
                    sq_alpha = model.state_dict()[layername + ".smoothq_alpha"]
                    sq_w_scale = v.abs().max(dim=0, keepdim=True).values.clamp(min=1e-5)
                    sq_scale = sq_a_scale.pow(sq_alpha) / sq_w_scale.pow(1 - sq_alpha)
                    v_scaled = v * sq_scale  # weights are sq-scaled before quantization
                    # guard the FP16 cast
                    if sq_scale.abs().max() > torch.finfo(torch.float16).max:
                        raise ValueError(
                            "Smoothquant scale (sq_scale) exceeds float16 range. "
                            "Aborting state dict saving."
                        )
                    new_sd[layername + ".smoothq_scale"] = (
                        sq_scale.squeeze().to(torch.float16).to("cpu")
                    )

            # quantize weights and store them in the state dict
            if layername + ".quantize_weight.clip_val" in model.state_dict():
                w_cv = model.state_dict()[layername + ".quantize_weight.clip_val"]
                if w_cv.numel() > 1:
                    w_cv = w_cv.unsqueeze(dim=1)
                weight_pre_quant = v_scaled if v_scaled is not None else v
                weight_int = torch.clamp(
                    127 / w_cv * weight_pre_quant, -127, 127
                ).round()
                new_sd[k] = weight_int.to(torch.int8).to("cpu")  # signed int8

            a_cv_name = layername + ".quantize_feature.clip_val"
            a_cvn_name = a_cv_name + "n"
            a_cv = None
            a_cvn = None
            if a_cv_name in model.state_dict():
                a_cv = model.state_dict()[a_cv_name]
            if a_cvn_name in model.state_dict():
                a_cvn = model.state_dict()[a_cvn_name]

            # compute the "zero_shift" correction factor only for asymmetric activations
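            # (background, assumed) with an asymmetric activation zero point z,
            # x @ W.T picks up a constant contribution z * sum_j W_int[i, j] for each
            # output feature i; the per-row sum of the int8 weights is therefore saved
            # so that the runtime can subtract this term.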
            if (
                a_cv is not None
                and a_cvn is not None
                and not torch.equal(a_cv, -a_cvn)
            ):
                if v.dim() == 2:
                    # weight_int: [out_feat, in_feat]
                    # sum (squash) along the in_feat dimension: dim=1
                    new_sd[layername + ".zero_shift"] = (
                        torch.sum(weight_int, dim=1).to(torch.float16).to("cpu")
                    )
                else:
                    raise NotImplementedError(
                        "Zero shift computation for tensors "
                        "with more than 2 dims is not supported yet."
                    )
        elif all(excluded_key not in k for excluded_key in excluded_keys_from_new_sd):
            # guard the FP16 cast
            if v.abs().max() > torch.finfo(torch.float16).max:
                raise ValueError(
                    f"Quantization parameter ({k}) exceeds float16 range. "
                    "Aborting state dict saving."
                )
            new_sd[k] = v.to("cpu").to(torch.float16)

    logger.info("New state dict processed.")
    if verbose:
        logger.info(
            "\n"
            + "\n".join(
                f"{k:80} {str(list(v.size())):15} "
                f"{str(v.dtype):18} {str(v.device):10} "
                f"{v.reshape(-1)[0].item():12.4f} "
                f"{v.min().item():12.4f} {v.max().item():12.4f}"
                for k, v in new_sd.items()
            )
        )

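    # For each quantized linear layer, new_sd typically ends up holding:
    #   "<layer>.weight"         int8  quantized weights
    #   "<layer>.smoothq_scale"  fp16  combined smoothquant scale (if smoothquant was used)
    #   "<layer>.zero_shift"     fp16  per-output-feature weight sums (asymmetric activations only)
    # plus the remaining non-excluded tensors cast to fp16, e.g. clip values and biases.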
    return new_sd


def save_sd_for_aiu(
    model: PreTrainedModel,
    output_dir: str,
    savename: str = "qmodel_state_dict.pt",
    verbose: bool = False,
) -> None:
    """Save the model state dictionary after conversion for AIU compatibility."""

    converted_sd = convert_sd_for_aiu(model, verbose)
    torch.save(converted_sd, Path(output_dir) / savename)
    logger.info("Converted state dict saved to " + str(Path(output_dir) / savename))
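

# Example usage (illustrative sketch, not part of the module's API; assumes `qmodel`
# is a PreTrainedModel that has already been quantized with FMS Model Optimizer and
# that the output directory already exists):
#
#     save_sd_for_aiu(qmodel, output_dir="./aiu_ckpt", verbose=True)
#     aiu_sd = torch.load("./aiu_ckpt/qmodel_state_dict.pt")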