|
| 1 | +import re |
| 2 | +from dataclasses import dataclass |
| 3 | +from typing import Optional |
| 4 | + |
| 5 | +import torch |
| 6 | +import torch.nn as nn |
| 7 | +import torch.nn.functional as F |
| 8 | +from transformers import PreTrainedModel |
| 9 | +from transformers.modeling_utils import Conv1D |
| 10 | + |
| 11 | + |
@dataclass
class WithIA3Config:
    """
    A class for configuring which layers to modify with IA3 adaptors.

    :param ia3_param_names:
        A string used as the name for all ia3 parameters
    :param attention_modules:
        A regex that matches all attention modules which are parents to the keys and value layers to modify.
    :param mlp_modules:
        A regex that matches all modules that are parents to the feed forward layer to modify.
    :param mlp_layers:
        A regex that matches the feed forward layer in the modules specified by `mlp_modules`.
    :param fused_qkv_layers:
        A regex that matches the combined query, key, and value layer in the modules specified
        by `attention_modules`.
    :param k_layers:
        A regex that matches the key layer in the modules specified by `attention_modules`.
    :param v_layers:
        A regex that matches the value layer in the modules specified by `attention_modules`.
    """

    ia3_param_names: str
    attention_modules: str
    mlp_modules: str
    mlp_layers: str
    # Per architecture, either a single fused qkv projection is matched
    # (fused_qkv_layers) or separate key/value projections are (k_layers /
    # v_layers); the unused options stay None.
    fused_qkv_layers: Optional[str] = None
    k_layers: Optional[str] = None
    v_layers: Optional[str] = None
| 42 | + |
| 43 | + |
# Pre-made per-architecture IA3 configurations. The regexes are matched with
# ``re.fullmatch`` against the names yielded by ``named_modules`` /
# ``named_children``.

GPT_J_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*attn",
    k_layers="k_proj",
    v_layers="v_proj",
    mlp_modules=".*mlp",
    mlp_layers="fc_in",
)

GPT_2_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*attn",
    fused_qkv_layers="c_attn",
    mlp_modules=".*mlp",
    mlp_layers="c_fc",
)

OPT_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*self_attn",
    k_layers="k_proj",
    v_layers="v_proj",
    mlp_modules=r".*layers\.\d*",
    mlp_layers="fc1",
)

BLOOM_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*self_attention",
    fused_qkv_layers="query_key_value",
    mlp_modules=".*mlp",
    mlp_layers="dense_h_to_4h",
)

_GPT_2_MODELS = (
    "sshleifer/tiny-gpt2",
    "gpt2",
    "gpt2-medium",
    "gpt2-large",
    "gpt2-xl",
)
_BLOOM_MODELS = (
    "bigscience/bloom-560m",
    "bigscience/bloom-1b1",
    "bigscience/bloom-1b7",
    "bigscience/bloom-3b",
    "bigscience/bloom-7b1",
    "bigscience/bloom",
)
_OPT_MODELS = (
    "facebook/opt-125m",
    "facebook/opt-350m",
    "facebook/opt-1.3b",
    "facebook/opt-2.7b",
    "facebook/opt-6.7b",
    "facebook/opt-13b",
    "facebook/opt-30b",
    "facebook/opt-66b",
)

# Lookup used by ``modify_with_ia3`` when no explicit config is given.
MODEL_NAME_TO_CONFIG = {
    **{name: GPT_2_IA3_CONFIG for name in _GPT_2_MODELS},
    **{name: BLOOM_IA3_CONFIG for name in _BLOOM_MODELS},
    **{name: OPT_IA3_CONFIG for name in _OPT_MODELS},
    "EleutherAI/gpt-j-6B": GPT_J_IA3_CONFIG,
}
| 100 | + |
| 101 | + |
class WithIA3(nn.Module):
    """
    Base module holding a learned (IA)^3 rescaling vector.

    Subclasses set ``self.out_features`` (and ``self.unfuse_size``) before
    calling ``super().__init__``, since this constructor reads
    ``self.out_features`` when ``unfuse_size`` is ``None``.
    """

    def __init__(self, ia3_param_names: str, unfuse_size: int = None):
        super().__init__()
        self.ia3_param_names = ia3_param_names

        # For a fused (q, k, v) projection only k and v are rescaled (not q),
        # so two groups of ``unfuse_size`` scales are needed; otherwise one
        # scale per output feature.
        if unfuse_size is None:
            num_scales = self.out_features  # type: ignore
        else:
            num_scales = unfuse_size * 2
        setattr(self, ia3_param_names, nn.Parameter(torch.ones(num_scales, 1)))

    def scale_by_ia3(self, x):
        scales = getattr(self, self.ia3_param_names)

        # Rescaling is applied only while the ia3 parameters are trainable.
        if not scales.requires_grad:
            return x

        if self.unfuse_size is None:
            return x * scales.flatten()

        # Fused qkv output: leave the q slice untouched, rescale k and v.
        q = x[:, :, : self.unfuse_size]
        kv = x[:, :, self.unfuse_size :] * scales.flatten()
        return torch.cat([q, kv], dim=2)
| 128 | + |
| 129 | + |
class LinearWithIA3(WithIA3):
    def __init__(self, linear_layer: nn.Linear, ia3_param_names: str, unfuse_size: int = None):
        """
        A replacement for :class:`~torch.nn.Linear` modified with an IA3 adaptor

        :param linear_layer:
            A :class:`~torch.nn.Linear` layer to adapt.
        :param ia3_param_names:
            A `str` to use as the name of ia3 parameters.
        :param unfuse_size:
            An `int` indicating hidden dimension of the query, key, and value vectors.
            To be used only when the layer to modify is a fused projection of query,
            key, and value vectors in an attention mechanism.
        """
        # A fused qkv projection must produce exactly 3 * unfuse_size outputs.
        assert unfuse_size is None or linear_layer.out_features == unfuse_size * 3

        # Assigned before WithIA3.__init__ runs, which reads out_features when
        # unfuse_size is None.
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.unfuse_size = unfuse_size

        super().__init__(ia3_param_names, unfuse_size)

        # Share the wrapped layer's parameters instead of copying them.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias

    def forward(self, x):
        return self.scale_by_ia3(F.linear(x, self.weight, self.bias))
| 158 | + |
| 159 | + |
class Conv1DWithIA3(WithIA3):
    def __init__(self, conv1d_layer: Conv1D, ia3_param_names: str, unfuse_size: int = None):
        """
        A replacement for :class:`~transformers.modeling_utils.Conv1D` modified with an IA3 adaptor

        :param conv1d_layer:
            A :class:`~transformers.modeling_utils.Conv1D` layer to adapt.
        :param ia3_param_names:
            A `str` to use as the name of ia3 parameters.
        :param unfuse_size:
            An `int` indicating hidden dimension of the query, key, and value vectors.
            To be used only when the layer to modify is a fused projection of query,
            key, and value vectors in an attention mechanism.
        """
        # A fused qkv projection must produce exactly 3 * unfuse_size outputs.
        assert unfuse_size is None or conv1d_layer.nf == unfuse_size * 3

        # Conv1D stores its output width in ``nf`` (and input width in ``nx``).
        self.out_features = conv1d_layer.nf
        self.unfuse_size = unfuse_size

        super().__init__(ia3_param_names, unfuse_size)

        # Share the wrapped layer's parameters instead of copying them.
        self.weight = conv1d_layer.weight
        self.bias = conv1d_layer.bias

    def forward(self, x):
        # Same computation as the original Conv1D implementation: a linear map
        # with the weight stored transposed relative to nn.Linear.
        out_shape = x.size()[:-1] + (self.out_features,)
        flat = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        return self.scale_by_ia3(flat.view(out_shape))
| 193 | + |
| 194 | + |
def modify_with_ia3(
    transformer: PreTrainedModel,
    *,
    config: Optional[WithIA3Config] = None,
    only_ia3_requires_grad: bool = True,
) -> PreTrainedModel:
    """
    A function to add ia3 adaptors to the given transformer. Code modified from
    `t-few <https://github.com/r-three/t-few/blob/217cfa3b73aa66a07594826e4ebbbc516b331461/src/models/lora.py>`_
    and Qinyuan Ye

    :param transformer:
        A :class:`~transformers.PreTrainedModel` to modify.
    :param config:
        A :class:`~tango.integrations.transformers.ia3.WithIA3Config` that specifies the layers to modify.
        If ``None``, a pre-made configuration is looked up from the model's name.
    :param only_ia3_requires_grad:
        A `bool`, `True` if `requires_grad` should only be set on ia3 parameters in the output model.

    Examples
    --------

    You can use the provided configurations:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=GPT_2_IA3_CONFIG)

    Or you can write your own configuration with regex matching the layers to modify and their parents:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import WithIA3Config, modify_with_ia3

        my_config = WithIA3Config(
            attention_modules=".*attn",
            fused_qkv_layers="c_attn",
            mlp_modules=".*mlp",
            mlp_layers="c_fc",
            ia3_param_names="ia3",
        )

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=my_config)
    """
    if config is None:
        model_name = transformer.config._name_or_path  # type: ignore
        assert (
            model_name in MODEL_NAME_TO_CONFIG
        ), f"{model_name} does not have a pre made configuration; please make your own."
        config = MODEL_NAME_TO_CONFIG[model_name]

    # The regex matching any attention child layer is loop-invariant; build it once.
    attn_layers = "|".join(
        regex
        for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers)
        if regex is not None
    )

    # Snapshot named_modules into a dict so replacing children mid-walk is safe.
    for m_name, module in dict(transformer.named_modules()).items():  # type: ignore
        is_attention = re.fullmatch(config.attention_modules, m_name) is not None
        if not is_attention and not re.fullmatch(config.mlp_modules, m_name):
            continue
        layers_to_change = attn_layers if is_attention else config.mlp_layers
        for c_name, layer in dict(module.named_children()).items():
            if not re.fullmatch(layers_to_change, c_name):
                continue
            assert isinstance(
                layer, (Conv1D, nn.Linear)
            ), "This code only supports Conv1D and nn.Linear"
            adaptor_class = Conv1DWithIA3 if isinstance(layer, Conv1D) else LinearWithIA3
            # Only a fused qkv layer needs unfuse_size, so the adaptor knows
            # where the q / k / v slices begin.
            new_module = adaptor_class(
                layer,
                config.ia3_param_names,
                unfuse_size=transformer.config.hidden_size  # type: ignore
                if config.fused_qkv_layers and re.fullmatch(config.fused_qkv_layers, c_name)
                else None,
            )
            setattr(module, c_name, new_module)

    if only_ia3_requires_grad:
        # Freeze everything, then re-enable gradients only for ia3 parameters.
        transformer.requires_grad_(False)  # type: ignore
        for p_name, v in dict(transformer.named_parameters()).items():  # type: ignore
            # Substring containment instead of the previous ".*name.*" regex:
            # equivalent for ordinary names, and robust if ia3_param_names
            # happens to contain regex metacharacters.
            if config.ia3_param_names in p_name:
                v.requires_grad_(True)

    return transformer
0 commit comments