# limitations under the License.

import math
+from typing import Optional, Union

import torch
import torch.nn as nn
import torch.nn.functional as F

from .config import TRANSFORMERS_MODEL_CONFIG


-class AdaptedAttention(nn.Module):
-    """This module wraps a LLamaAttention module and injects adaption prompts."""
+class _BaseAdaptedAttention(nn.Module):
+    """Base module, which defines adaption prompts for multiple model types."""

-    def __init__(self, model_type: str, adapter_len: int, model):
+    def __init__(self, model_type: str, adapter_len: int, model, target_dtype=torch.float32):
        """
        Initialize object.

@@ -34,31 +35,128 @@ def __init__(self, model_type: str, adapter_len: int, model):
            adapter_len: The length of the adaption prompt to insert.
            model: The original transformer attention module that is being wrapped.
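+            target_dtype: The dtype to use for the adaption prompt and gate parameters (defaults to torch.float32).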
3637 """
37- assert not isinstance (model , AdaptedAttention )
38+ if isinstance (model , _BaseAdaptedAttention ):
39+ raise ValueError ("Unable to stack multiple adaption prompts" )
3840 super ().__init__ ()
3941 self .model_type = model_type
4042 self .model = model
4143 self .adapter_len = adapter_len
4244 # Assume all parameters of the attention model we are wrapping are on the same device.
45+
4346 device = next (model .parameters ()).device
4447 # Don't think this was specified in the paper, but we follow the official repo which used an Embedding
4548 # which initializes the tokens with standard normal values.
4649 # https://github.com/ZrrSkywalker/LLaMA-Adapter/blob/41c3546fe1997ab8a65809dc8d8f9252b19d9faf/llama/model.py#L234
4750 # (bsz, adapter_len, hidden_size)
48- target_dtype = (
49- model .q_proj .weight .dtype if model .q_proj .weight .dtype not in [torch .int8 , torch .uint8 ] else torch .float32
50- )
51+
5152 if hasattr (self .model , "hidden_size" ):
5253 # TODO: remove this clause after 2026-01-01
5354 hidden_size = self .model .hidden_size
5455 else : # changed in https://github.com/huggingface/transformers/pull/35235
5556 hidden_size = self .model .config .hidden_size
57+
58+ if hasattr (self .model , "num_heads" ):
59+ # TODO: remove this clause after 2026-01-01
60+ self .num_heads = self .model .num_heads
61+ else : # changed in https://github.com/huggingface/transformers/pull/35235
62+ self .num_heads = self .model .config .num_attention_heads
63+
5664 self .adaption_prompt = nn .Parameter (
5765 torch .empty (1 , adapter_len , hidden_size , device = device , dtype = target_dtype ).normal_ ()
5866 )
5967 # Initialize the gate to 0 as this is "zero-init".
6068 self .adaption_gate = nn .Parameter (torch .zeros (1 , device = device , dtype = target_dtype ))
6169
70+
+class AdaptedAttentionGPT(_BaseAdaptedAttention):
+    """This module wraps a GPT2Attention module and injects adaption prompts."""
+
+    def __init__(self, model_type, adapter_len, model):
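+        # Match the dtype of the wrapped c_proj weight, falling back to fp32 for quantized
+        # (int8/uint8) weights, since the prompt embedding cannot be initialized in an integer dtype.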
+        target_dtype = (
+            model.c_proj.weight.dtype if model.c_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
+        )
+        super().__init__(model_type, adapter_len, model, target_dtype=target_dtype)
+
+    def forward(
+        self,
+        hidden_states: Optional[tuple[torch.FloatTensor]],
+        layer_past: Optional[tuple[torch.Tensor]] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        use_cache: Optional[bool] = False,
+        output_attentions: Optional[bool] = False,
+        **kwargs,
+    ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
+        """
+        Forward pass for the adapter which wraps the GPT2Attention module.
+        """
+        attn_outputs = self.model(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            encoder_hidden_states=encoder_hidden_states,
+            encoder_attention_mask=encoder_attention_mask,
+            use_cache=use_cache,
+            output_attentions=output_attentions,
+            **kwargs,
+        )
+
+        attn_output = attn_outputs[0]
+        add_outputs = attn_outputs[1:]
+
+        c_attn_layer = TRANSFORMERS_MODEL_CONFIG[self.model_type].k_proj_layer
+
+        bsz = attn_output.shape[0]
+        q_len = attn_output.shape[1]
+        embed_dim = attn_output.shape[2]
+
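+        # GPT-2 uses a single fused projection (c_attn) whose output has width 3 * embed_dim;
+        # splitting by embed_dim yields query, key and value, and only the prompt's key and value are needed.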
+        _, key, value = getattr(self.model, c_attn_layer)(self.adaption_prompt).split(embed_dim, dim=2)
+
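+        # Expand to (bsz, num_heads, adapter_len, head_dim) so the prompt keys/values match the
+        # per-head attention layout.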
+        adapter_k = (
+            key.view(1, self.adapter_len, self.num_heads, self.model.head_dim).repeat(bsz, 1, 1, 1).transpose(1, 2)
+        )
+        adapter_v = (
+            value.view(1, self.adapter_len, self.num_heads, self.model.head_dim).repeat(bsz, 1, 1, 1).transpose(1, 2)
+        )
+        # Recompute the query states since they are not returned by the GPT2 forward pass.
+        compute_query_states = TRANSFORMERS_MODEL_CONFIG[self.model_type].compute_query_states
+        query_states = compute_query_states(
+            self.model, hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states
+        )
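+        # query_states has shape (bsz, num_heads, q_len, head_dim).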
+
+        previous_dtype = query_states.dtype
+
+        scores = torch.matmul(query_states, adapter_k.transpose(2, 3).to(previous_dtype)) / math.sqrt(
+            self.model.head_dim
+        )
+        # Upcast attention to fp32
+        # (bsz, num_heads, q_len, adapter_len)
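+        # The zero-initialized adaption_gate scales the adapter contribution, so training starts
+        # from the unmodified GPT-2 attention output.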
+        scores = self.adaption_gate * F.softmax(scores, dim=-1, dtype=torch.float32).to(previous_dtype)
+        # (bsz, q_len, num_heads * head_dim)
+        adapter_output = torch.matmul(scores, adapter_v).transpose(1, 2).reshape(bsz, q_len, -1)
+
+        # Add the adaption prompt output to the original output.
+        hidden_state = attn_output + adapter_output
+
+        # Restore the original dtype.
+        hidden_state = hidden_state.to(previous_dtype)
+
+        # Add the additional attention outputs (attentions and cross attentions).
+        output = (hidden_state,) + add_outputs
+        return output
+
+
+class AdaptedAttention(_BaseAdaptedAttention):
+    """This module wraps a LlamaAttention module and injects adaption prompts."""
+
+    def __init__(self, model_type, adapter_len, model):
+        target_dtype = (
+            model.q_proj.weight.dtype if model.q_proj.weight.dtype not in [torch.int8, torch.uint8] else torch.float32
+        )
+        super().__init__(model_type, adapter_len, model, target_dtype=target_dtype)
+
    def forward(self, **kwargs):
        """
        Forward pass for the adapter which wraps the original LlamaAttention module.