Skip to content

Commit 7382019

Browse files
authored
IA3 adaptors (#403)
1 parent d498cf7 commit 7382019

File tree

4 files changed

+328
-0
lines changed

4 files changed

+328
-0
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
99

1010
### Added
1111

12+
- Added a function to modify a Hugging Face transformer with IA3 adaptors
1213
- Added a `BeakerScheduler` registrable class, specified as the argument `scheduler` to `BeakerExecutor`, which controls the resources assigned to steps ran on Beaker.
1314
Users can implement their own `BeakerScheduler` subclasses to customize the resource assignment behavior.
1415

docs/source/api/integrations/transformers.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,3 +17,5 @@ Reference
1717

1818
.. autoclass:: tango.integrations.transformers.RunGenerationDataset
1919
:members:
20+
21+
.. autofunction:: tango.integrations.transformers.ia3.modify_with_ia3
Lines changed: 287 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,287 @@
1+
import re
2+
from dataclasses import dataclass
3+
from typing import Optional
4+
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
from transformers import PreTrainedModel
9+
from transformers.modeling_utils import Conv1D
10+
11+
12+
@dataclass
class WithIA3Config:
    """
    A class for configuring which layers to modify with IA3 adaptors.

    :param ia3_param_names:
        A string used as the name for all ia3 parameters
    :param attention_modules:
        A regex that matches all attention modules which are parents to the keys and value layers to modify.
    :param mlp_modules:
        A regex that matches all modules that are parents to the feed forward layer to modify.
    :param mlp_layers:
        A regex that matches the feed forward layer in the modules specified by `mlp_modules`.
    :param fused_qkv_layers:
        A regex that matches the combined query, key, and value layer in the modules specified
        by `attention_modules`.
    :param k_layers:
        A regex that matches the key layer in the modules specified by `attention_modules`.
    :param v_layers:
        A regex that matches the value layer in the modules specified by `attention_modules`.
    """

    ia3_param_names: str
    attention_modules: str
    mlp_modules: str
    mlp_layers: str
    fused_qkv_layers: Optional[str] = None
    k_layers: Optional[str] = None
    v_layers: Optional[str] = None
42+
43+
44+
# Per-architecture IA3 configurations.  GPT-2 and BLOOM fuse query/key/value
# into a single projection, so they use ``fused_qkv_layers``; GPT-J and OPT
# expose separate key and value projections instead.
GPT_J_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*attn",
    k_layers="k_proj",
    v_layers="v_proj",
    mlp_modules=".*mlp",
    mlp_layers="fc_in",
)

GPT_2_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*attn",
    fused_qkv_layers="c_attn",
    mlp_modules=".*mlp",
    mlp_layers="c_fc",
)

OPT_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*self_attn",
    k_layers="k_proj",
    v_layers="v_proj",
    mlp_modules=r".*layers\.\d*",
    mlp_layers="fc1",
)

BLOOM_IA3_CONFIG = WithIA3Config(
    ia3_param_names="ia3",
    attention_modules=".*self_attention",
    fused_qkv_layers="query_key_value",
    mlp_modules=".*mlp",
    mlp_layers="dense_h_to_4h",
)

# Maps Hugging Face model names to their pre-made IA3 configurations.
MODEL_NAME_TO_CONFIG = {
    model_name: ia3_config
    for model_names, ia3_config in (
        (
            ("sshleifer/tiny-gpt2", "gpt2", "gpt2-medium", "gpt2-large", "gpt2-xl"),
            GPT_2_IA3_CONFIG,
        ),
        (
            (
                "bigscience/bloom-560m",
                "bigscience/bloom-1b1",
                "bigscience/bloom-1b7",
                "bigscience/bloom-3b",
                "bigscience/bloom-7b1",
                "bigscience/bloom",
            ),
            BLOOM_IA3_CONFIG,
        ),
        (
            (
                "facebook/opt-125m",
                "facebook/opt-350m",
                "facebook/opt-1.3b",
                "facebook/opt-2.7b",
                "facebook/opt-6.7b",
                "facebook/opt-13b",
                "facebook/opt-30b",
                "facebook/opt-66b",
            ),
            OPT_IA3_CONFIG,
        ),
        (("EleutherAI/gpt-j-6B",), GPT_J_IA3_CONFIG),
    )
    for model_name in model_names
}
100+
101+
102+
class WithIA3(nn.Module):
    """
    Base class for layers adapted with IA3: registers a learned scaling vector
    under the attribute name ``ia3_param_names`` and applies it to layer
    outputs in :meth:`scale_by_ia3`.
    """

    def __init__(self, ia3_param_names: str, unfuse_size: Optional[int] = None):
        """
        :param ia3_param_names:
            A `str` to use as the attribute name of the ia3 parameters.
        :param unfuse_size:
            Hidden dimension of the query, key, and value vectors when the
            adapted layer is a fused q/k/v projection; ``None`` otherwise.
        """
        super().__init__()
        self.ia3_param_names = ia3_param_names
        # Store on the base class so `scale_by_ia3` always finds it, even if a
        # subclass forgets to set it (subclasses may assign it earlier too).
        self.unfuse_size = unfuse_size

        # if (q,k,v) are stacked into one layer
        if unfuse_size is not None:
            # IA3 only operates on k and v (not q), thus the "* 2"
            setattr(self, ia3_param_names, nn.Parameter(torch.ones(unfuse_size * 2, 1)))
        else:
            # Subclasses must set `self.out_features` (the adapted layer's
            # output width) before calling this constructor.
            setattr(self, ia3_param_names, nn.Parameter(torch.ones(self.out_features, 1)))  # type: ignore

    def scale_by_ia3(self, x):
        """Scale the layer output ``x`` elementwise by the ia3 parameters."""
        ia3_params = getattr(self, self.ia3_param_names)

        # NOTE(review): scaling is skipped when the ia3 parameters are frozen;
        # this presumes frozen params remain at their all-ones init (a no-op
        # scale) — confirm against checkpoints that load trained ia3 weights.
        if ia3_params.requires_grad:
            if self.unfuse_size is not None:
                # non_q means k and v; q occupies the first `unfuse_size` dims
                q, non_q = x[:, :, : self.unfuse_size], x[:, :, self.unfuse_size :]
                non_q = non_q * ia3_params.flatten()
                x = torch.cat([q, non_q], dim=2)
            else:
                x = x * ia3_params.flatten()

        return x
128+
129+
130+
class LinearWithIA3(WithIA3):
    def __init__(
        self, linear_layer: nn.Linear, ia3_param_names: str, unfuse_size: Optional[int] = None
    ):
        """
        A replacement for :class:`~torch.nn.Linear` modified with an IA3 adaptor

        :param linear_layer:
            A :class:`~torch.nn.Linear` layer to adapt.
        :param ia3_param_names:
            A `str` to use as the name of ia3 parameters.
        :param unfuse_size:
            An `int` indicating hidden dimension of the query, key, and value vectors.
            To be used only when the layer to modify is a fused projection of query,
            key, and value vectors in an attention mechanism.
        """
        # A fused q/k/v projection must be exactly three hidden-size chunks wide.
        assert unfuse_size is None or (linear_layer.out_features == unfuse_size * 3)
        # These must be set before super().__init__, which reads `out_features`
        # to size the ia3 parameters in the non-fused case.
        self.in_features = linear_layer.in_features
        self.out_features = linear_layer.out_features
        self.unfuse_size = unfuse_size

        super().__init__(ia3_param_names, unfuse_size)

        # Share (not copy) the wrapped layer's parameters.
        self.weight = linear_layer.weight
        self.bias = linear_layer.bias

    def forward(self, x):
        """Apply the wrapped linear projection, then the ia3 scaling."""
        x = F.linear(x, self.weight, self.bias)
        return self.scale_by_ia3(x)
158+
159+
160+
class Conv1DWithIA3(WithIA3):
    def __init__(
        self, conv1d_layer: Conv1D, ia3_param_names: str, unfuse_size: Optional[int] = None
    ):
        """
        A replacement for :class:`~transformers.modeling_utils.Conv1D` modified with an IA3 adaptor

        :param conv1d_layer:
            A :class:`~transformers.modeling_utils.Conv1D` layer to adapt.
        :param ia3_param_names:
            A `str` to use as the name of ia3 parameters.
        :param unfuse_size:
            An `int` indicating hidden dimension of the query, key, and value vectors.
            To be used only when the layer to modify is a fused projection of query,
            key, and value vectors in an attention mechanism.
        """
        # A fused q/k/v projection must be exactly three hidden-size chunks wide.
        assert unfuse_size is None or (conv1d_layer.nf == unfuse_size * 3)

        # nf: number of output features; nx: number of input features.
        # `out_features` must be set before super().__init__, which reads it
        # to size the ia3 parameters in the non-fused case.
        self.out_features = conv1d_layer.nf
        self.unfuse_size = unfuse_size

        super().__init__(ia3_param_names, unfuse_size)

        # Share (not copy) the wrapped layer's parameters.
        self.weight = conv1d_layer.weight
        self.bias = conv1d_layer.bias

    def forward(self, x):
        """Apply the wrapped Conv1D projection, then the ia3 scaling."""
        # Copied and pasted from the original Conv1D implementation.
        size_out = x.size()[:-1] + (self.out_features,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        x = x.view(size_out)  # ... * self.out_features

        return self.scale_by_ia3(x)
193+
194+
195+
def modify_with_ia3(
    transformer: PreTrainedModel,
    *,
    config: Optional[WithIA3Config] = None,
    only_ia3_requires_grad: bool = True,
) -> PreTrainedModel:
    """
    A function to add ia3 adaptors to the given transformer. Code modified from
    `t-few <https://github.com/r-three/t-few/blob/217cfa3b73aa66a07594826e4ebbbc516b331461/src/models/lora.py>`_
    and Qinyuan Ye

    :param transformer:
        A :class:`~transformers.PreTrainedModel` to modify.
    :param config:
        A :class:`~tango.integrations.transformers.ia3.WithIA3Config` that specifies the layers to modify.
        If not given, a pre-made configuration is looked up from the model's name.
    :param only_ia3_requires_grad:
        A `bool`, `True` if `requires_grad` should only be set on ia3 parameters in the output model.

    Examples
    --------

    You can use the provided configurations:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import modify_with_ia3, GPT_2_IA3_CONFIG

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=GPT_2_IA3_CONFIG)

    Or you can write your own configuration with regex matching the layers to modify and their parents:

    .. testcode::

        from transformers import AutoModelForCausalLM, AutoTokenizer
        from tango.integrations.transformers.ia3 import WithIA3Config, modify_with_ia3

        my_config = WithIA3Config(
            attention_modules=".*attn",
            fused_qkv_layers="c_attn",
            mlp_modules=".*mlp",
            mlp_layers="c_fc",
            ia3_param_names="ia3",
        )

        model = AutoModelForCausalLM.from_pretrained("sshleifer/tiny-gpt2")
        model = modify_with_ia3(model, config=my_config)
    """
    if config is None:
        # Fall back to a pre-made configuration keyed by the model name.
        model_name = transformer.config._name_or_path  # type: ignore
        assert (
            model_name in MODEL_NAME_TO_CONFIG
        ), f"{model_name} does not have a pre made configuration; please make your own."
        config = MODEL_NAME_TO_CONFIG[model_name]

    for m_name, module in dict(transformer.named_modules()).items():  # type: ignore
        if re.fullmatch(config.attention_modules, m_name) or re.fullmatch(
            config.mlp_modules, m_name
        ):
            # Inside an attention module we look for the (possibly fused)
            # q/k/v child layers; inside an MLP module, the feed-forward layer.
            attn_layers = [
                regex
                for regex in (config.fused_qkv_layers, config.k_layers, config.v_layers)
                if regex is not None
            ]
            layers_to_change = (
                "|".join(attn_layers)
                if re.fullmatch(config.attention_modules, m_name)
                else config.mlp_layers
            )
            for c_name, layer in dict(module.named_children()).items():
                if re.fullmatch(layers_to_change, c_name):
                    assert isinstance(layer, Conv1D) or isinstance(
                        layer, nn.Linear
                    ), "This code only supports Conv1D and nn.Linear"
                    adaptor_class = Conv1DWithIA3 if isinstance(layer, Conv1D) else LinearWithIA3
                    # A fused q/k/v layer needs the hidden size so the adaptor
                    # can split q from k/v; other layers get unfuse_size=None.
                    new_module = adaptor_class(
                        layer,
                        config.ia3_param_names,
                        unfuse_size=transformer.config.hidden_size  # type: ignore
                        if config.fused_qkv_layers and re.fullmatch(config.fused_qkv_layers, c_name)
                        else None,
                    )
                    setattr(module, c_name, new_module)

    if only_ia3_requires_grad:
        # Freeze everything, then re-enable gradients on ia3 parameters only.
        transformer.requires_grad_(False)  # type: ignore
        for p_name, v in dict(transformer.named_parameters()).items():  # type: ignore
            if re.fullmatch(".*" + config.ia3_param_names + ".*", p_name):
                v.requires_grad_(True)

    return transformer
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import torch
2+
from transformers import AutoModelForCausalLM, AutoTokenizer
3+
4+
from tango.integrations.transformers.ia3 import GPT_2_IA3_CONFIG, modify_with_ia3
5+
6+
7+
def test_ia3():
    """Adding IA3 adaptors (initialized to ones) must not change model outputs."""
    model_name = "sshleifer/tiny-gpt2"
    ia3_config = GPT_2_IA3_CONFIG

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    batch = tokenizer(["A tiny test on a tiny model."], return_tensors="pt")

    def forward(m):
        # Run a full forward pass (with labels, so a loss is produced too).
        with torch.inference_mode():
            return m(
                input_ids=batch.input_ids,
                attention_mask=batch.attention_mask,
                labels=batch.input_ids,
            )

    model = AutoModelForCausalLM.from_pretrained(model_name).eval()
    old_outputs = forward(model)
    new_outputs = forward(modify_with_ia3(model, config=ia3_config))

    assert torch.abs(old_outputs.logits - new_outputs.logits).mean() < 1e-10
    assert torch.abs(old_outputs.loss - new_outputs.loss) < 1e-10

0 commit comments

Comments
 (0)