
Commit 663b120

ENH Llama-Adapters support for GPT2 (#2643)
aka "adaption prompt"
1 parent 04a5ed7 commit 663b120

File tree: 5 files changed, +156 -16 lines changed

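For orientation, a minimal usage sketch of what this commit enables (not part of the diff): applying adaption prompt to a GPT-2 checkpoint through the standard PEFT entry points. The config fields adapter_len and adapter_layers are the ones this tuner already uses; the concrete values here are illustrative.

# Illustrative only: adaption prompt on GPT-2 via the standard PEFT workflow.
from transformers import AutoModelForCausalLM
from peft import AdaptionPromptConfig, get_peft_model

base_model = AutoModelForCausalLM.from_pretrained("gpt2")
config = AdaptionPromptConfig(
    adapter_len=10,      # length of the learned prompt providing extra key/value states
    adapter_layers=2,    # number of top attention layers to wrap
    task_type="CAUSAL_LM",
)
peft_model = get_peft_model(base_model, config)
peft_model.print_trainable_parameters()  # only adaption_prompt / adaption_gate parameters are trainable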

src/peft/tuners/adaption_prompt/config.py
Lines changed: 9 additions & 2 deletions

@@ -18,7 +18,7 @@
 from peft.config import PeftConfig
 from peft.utils import PeftType
 
-from .utils import llama_compute_query_states
+from .utils import gpt2_compute_query_states, llama_compute_query_states
 
 
 @dataclass
@@ -62,6 +62,13 @@ def is_adaption_prompt(self) -> bool:
         v_proj_layer="v_proj",
         o_proj_layer="o_proj",
     ),
+    "gpt2": ModelTypeConfig(  # piggybacking off of the prior definitions; GPT-2's attention calculation is different
+        compute_query_states=gpt2_compute_query_states,
+        target_modules="attn",
+        k_proj_layer="c_attn",
+        v_proj_layer=None,
+        o_proj_layer=None,
+    ),
 }
 
 
@@ -71,7 +78,7 @@ def prepare_config(
 ) -> AdaptionPromptConfig:
     """Prepare the config based on the llama model type."""
     if model.config.model_type not in TRANSFORMERS_MODEL_CONFIG:
-        raise ValueError("Unsupported model type for adaption prompt: '{model.config.model_type}'.")
+        raise ValueError(f"Unsupported model type for adaption prompt: '{model.config.model_type}'.")
 
     model_config = TRANSFORMERS_MODEL_CONFIG[model.config.model_type]
 
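Context for the new "gpt2" entry (an illustration, not part of the diff): unlike Llama, GPT-2 has no separate q/k/v projection layers; a single fused c_attn projection produces query, key and value in one tensor, which is why k_proj_layer points at c_attn while v_proj_layer and o_proj_layer are None. A small sketch, assuming the GPT2Attention internals referenced in this diff (c_attn, split_size):

# Sketch: GPT-2's fused QKV projection, split the same way the adapter code does.
import torch
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention

attn = GPT2Attention(GPT2Config())
hidden = torch.randn(1, 4, attn.embed_dim)                              # (bsz, seq_len, hidden)
query, key, value = attn.c_attn(hidden).split(attn.split_size, dim=2)  # each (1, 4, hidden)
print(query.shape, key.shape, value.shape)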

src/peft/tuners/adaption_prompt/layer.py
Lines changed: 105 additions & 7 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import math
+from typing import Optional, Union
 
 import torch
 import torch.nn as nn
@@ -21,10 +22,10 @@
 from .config import TRANSFORMERS_MODEL_CONFIG
 
 
-class AdaptedAttention(nn.Module):
-    """This module wraps a LLamaAttention module and injects adaption prompts."""
+class _BaseAdaptedAttention(nn.Module):
+    """Base module, which defines adaption prompts for multiple model types."""
 
-    def __init__(self, model_type: str, adapter_len: int, model):
+    def __init__(self, model_type: str, adapter_len: int, model, target_dtype=torch.float32):
         """
         Initialize object.
 
@@ -34,31 +35,128 @@ def __init__(self, model_type: str, adapter_len: int, model):
             adapter_len: The length of the adaption prompt to insert.
             model: The original transformer attention module that is being wrapped.
         """
-        assert not isinstance(model, AdaptedAttention)
+        if isinstance(model, _BaseAdaptedAttention):
+            raise ValueError("Unable to stack multiple adaption prompts")
         super().__init__()
         self.model_type = model_type
         self.model = model
         self.adapter_len = adapter_len
         # Assume all parameters of the attention model we are wrapping are on the same device.
+
         device = next(model.parameters()).device
         # Don't think this was specified in the paper, but we follow the official repo which used an Embedding
         # which initializes the tokens with standard normal values.
         # https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L234
         # (bsz, adapter_len, hidden_size)
-        target_dtype = (
-            model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
-        )
+
         if hasattr(self.model, "hidden_size"):
             # TODO: remove this clause after 2026-01-01
             hidden_size = self.model.hidden_size
         else:  # changed in https://github.com/huggingface/transformers/pull/35235
             hidden_size = self.model.config.hidden_size
+
+        if hasattr(self.model, "num_heads"):
+            # TODO: remove this clause after 2026-01-01
+            self.num_heads = self.model.num_heads
+        else:  # changed in https://github.com/huggingface/transformers/pull/35235
+            self.num_heads = self.model.config.num_attention_heads
+
         self.adaption_prompt = nn.Parameter(
             torch.empty(1, adapter_len, hidden_size, device=device, dtype=target_dtype).normal_()
         )
         # Initialize the gate to 0 as this is "zero-init".
         self.adaption_gate = nn.Parameter(torch.zeros(1, device=device, dtype=target_dtype))
 
+
+class AdaptedAttentionGPT(_BaseAdaptedAttention):
+    """This module wraps a GPT2Attention module and injects adaption prompts."""
+
+    def __init__(self, model_type, adapter_len, model):
+        target_dtype = (
+            model.c_proj.weight.dtype if model.c_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
+        )
+        super().__init__(model_type, adapter_len, model, target_dtype=target_dtype)
+
+    def forward(
+        self,
+        hidden_states: Optional[tuple[torch.FloatTensor]],
+        layer_past: Optional[tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs,
+    ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
+        """
+        Forward pass for the adapter which wraps the GPT2Attention module.
+        """
+        attn_outputs = self.model(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+
+        attn_output = attn_outputs[0]
+        add_outputs = attn_outputs[1:]
+
+        c_attn_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
+
+        bsz = attn_output.shape[0]
+        q_len = attn_output.shape[1]
+        embed_dim = attn_output.shape[2]
+
+        _, key, value = getattr(self.model, c_attn_layer)(self.adaption_prompt).split(embed_dim, dim=2)
+
+        adapter_k = (
+            key.view(1, self.adapter_len, self.num_heads, self.model.head_dim).repeat(bsz, 1, 1, 1).transpose(1, 2)
+        )
+        adapter_v = (
+            value.view(1, self.adapter_len, self.num_heads, self.model.head_dim).repeat(bsz, 1, 1, 1).transpose(1, 2)
+        )
+        # Recompute the query states since they are not returned by the GPT2 forward.
+        compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
+        query_states = compute_query_states(
+            self.model, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states
+        )
+
+        previous_dtype = query_states.dtype
+
+        scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
+            self.model.head_dim
+        )
+        # Upcast attention to fp32
+        # (bsz, num_heads, q_len, adapter_len)
+        scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
+        # (bsz, q_len, num_heads * head_dim)
+        adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)
+
+        # Add adaption prompt output to original output.
+        hidden_state = attn_output + adapter_output
+
+        # Restore original dtype.
+        hidden_state = hidden_state.to(previous_dtype)
+
+        # Add additional attention outputs (attention and cross attention).
+        output = (hidden_state,) + add_outputs
+        return output
+
+
+class AdaptedAttention(_BaseAdaptedAttention):
+    """This module wraps a LLamaAttention module and injects adaption prompts."""
+
+    def __init__(self, model_type, adapter_len, model):
+        target_dtype = (
+            model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
+        )
+        super().__init__(model_type, adapter_len, model, target_dtype=target_dtype)
+
     def forward(self, **kwargs):
         """
         Forward pass for the adapter which wraps the original LlamaAttention module.
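One property worth noting in the GPT-2 forward pass above: because adaption_gate is zero-initialized, the adapter branch contributes nothing at the start of training, so a freshly wrapped layer reproduces the base attention output exactly. A standalone sketch of that gating math (shapes follow the comments in the diff; all tensors here are made up):

# Sketch of the gated prefix-attention math with a zero-initialized gate.
import torch

bsz, num_heads, q_len, adapter_len, head_dim = 2, 4, 5, 10, 8
query = torch.randn(bsz, num_heads, q_len, head_dim)
adapter_k = torch.randn(bsz, num_heads, adapter_len, head_dim)
adapter_v = torch.randn(bsz, num_heads, adapter_len, head_dim)
adaption_gate = torch.zeros(1)  # "zero-init"

scores = (query @ adapter_k.transpose(2, 3)) / head_dim**0.5   # (bsz, num_heads, q_len, adapter_len)
scores = adaption_gate * torch.softmax(scores, dim=-1)
adapter_output = (scores @ adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)
assert torch.equal(adapter_output, torch.zeros(bsz, q_len, num_heads * head_dim))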

src/peft/tuners/adaption_prompt/model.py
Lines changed: 15 additions & 7 deletions

@@ -18,7 +18,7 @@
 from peft.utils import _freeze_adapter, _get_submodules
 
 from .config import AdaptionPromptConfig, prepare_config
-from .layer import AdaptedAttention
+from .layer import AdaptedAttention, AdaptedAttentionGPT
 from .utils import is_adaption_prompt_trainable
 
 
@@ -65,7 +65,7 @@ def add_adapter(self, adapter_name: str, config: AdaptionPromptConfig) -> None:
 
         parents = []
         for name, _ in self.model.named_modules():
-            if name.endswith(config.target_modules):
+            if name.endswith(f".{config.target_modules}"):
                 par, _, _ = _get_submodules(self.model, name)
                 parents.append(par)
         if len(parents) < config.adapter_layers:
@@ -118,11 +118,19 @@ def disable_adapter_layers(self):
     def _create_adapted_attentions(self, config: AdaptionPromptConfig, parents: list[nn.Module]) -> None:
         """Wrap LlamaAttention modules with newly created AdaptedAttention modules."""
         for par in parents:
-            attn = AdaptedAttention(
-                model_type=self.model.config.model_type,
-                adapter_len=config.adapter_len,
-                model=getattr(par, config.target_modules),
-            )
+            if self.model.config.model_type == "gpt2":
+                attn = AdaptedAttentionGPT(
+                    model_type=self.model.config.model_type,
+                    adapter_len=config.adapter_len,
+                    model=getattr(par, config.target_modules),
+                )
+            else:
+                attn = AdaptedAttention(
+                    model_type=self.model.config.model_type,
+                    adapter_len=config.adapter_len,
+                    model=getattr(par, config.target_modules),
+                )
             setattr(par, config.target_modules, attn)
 
     def _set_adapted_attentions(self, adapter_name: str) -> None:
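A note on the endswith change above (my reading, not stated in the diff): with target_modules="attn", a bare suffix match would also catch GPT-2's inner projection modules such as c_attn, so requiring the leading dot restricts wrapping to the attention block itself. A quick check with hypothetical module names in the GPT-2 layout:

# Only the dotted suffix isolates the attention block.
names = ["transformer.h.0.attn", "transformer.h.0.attn.c_attn"]
print([n for n in names if n.endswith("attn")])    # ['transformer.h.0.attn', 'transformer.h.0.attn.c_attn']
print([n for n in names if n.endswith(".attn")])   # ['transformer.h.0.attn']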

src/peft/tuners/adaption_prompt/utils.py
Lines changed: 26 additions & 0 deletions

@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
+from typing import Optional
 
 import torch
 import torch.nn as nn
@@ -127,6 +128,31 @@ def llama_compute_query_states(model: nn.Module, **kwargs) -> torch.Tensor:
     return (query_states * cos) + (llama_rotate_half(query_states) * sin)
 
 
+def gpt2_compute_query_states(
+    model: nn.Module,
+    hidden_states: Optional[tuple[torch.FloatTensor]],
+    encoder_hidden_states: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    Compute query states for GPT2 models. They need to be recomputed as the forward() method of GPT2 in the
+    transformers library does not return them. See the related discussion in the PR.
+    """
+    if encoder_hidden_states is not None:
+        if not hasattr(model, "q_attn"):
+            raise ValueError(
+                f"If `{model.__class__.__name__}` is used as cross attention, the weights `q_attn` must be defined. "
+                "Please make sure to instantiate it with `GPT2Attention(..., is_cross_attention=True)`."
+            )
+        query_states = model.q_attn(hidden_states)
+    else:
+        query_states, _, _ = model.c_attn(hidden_states).split(model.split_size, dim=2)
+
+    shape_q = (*query_states.shape[:-1], -1, model.head_dim)
+    query_states = query_states.view(shape_q).transpose(1, 2)
+
+    return query_states
+
+
 def is_adaption_prompt_trainable(params: str) -> bool:
     """Return True if module is trainable under adaption prompt fine-tuning."""
     return params.split(".")[-1].startswith("adaption_")
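A small shape check for the new helper (a sketch, assuming this commit is installed so gpt2_compute_query_states is importable from the tuner's utils module; it exercises only the self-attention path):

# Sketch: recomputing GPT-2 query states outside the adapter.
import torch
from transformers import GPT2Config
from transformers.models.gpt2.modeling_gpt2 import GPT2Attention
from peft.tuners.adaption_prompt.utils import gpt2_compute_query_states

attn = GPT2Attention(GPT2Config())
hidden_states = torch.randn(2, 7, attn.embed_dim)
query_states = gpt2_compute_query_states(attn, hidden_states=hidden_states)
print(query_states.shape)  # (2, num_heads, 7, head_dim)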

tests/test_adaption_prompt.py
Lines changed: 1 addition & 0 deletions

@@ -29,6 +29,7 @@
 
 
 MODELS_TO_TEST = [
+    "hf-internal-testing/tiny-random-gpt2",
     "trl-internal-testing/tiny-random-LlamaForCausalLM",
     "hf-internal-testing/tiny-random-MistralForCausalLM",
 ]
