 23 |  23 |
 24 |  24 | LOG = get_logger(__name__)
 25 |  25 |
 26 |     | -ORIGINAL_QKV_CODE = """
    |  26 | +QKV_PATCHES = [
    |  27 | +    (
    |  28 | +        """
 27 |  29 |     query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 28 |  30 |     key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 29 |  31 |     value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
 30 |  32 | """.lstrip(
 31 |     | -    "\n"
 32 |     | -)
 33 |     | -
 34 |     | -PATCHED_QKV_CODE = """
    |  33 | +            "\n"
    |  34 | +        ),
    |  35 | +        """
 35 |  36 |     query_states, key_states, value_states = self.apply_qkv(hidden_states)
 36 |  37 |     query_states = query_states.view(hidden_shape).transpose(1, 2)
 37 |  38 |     key_states = key_states.view(hidden_shape).transpose(1, 2)
 38 |  39 |     value_states = value_states.view(hidden_shape).transpose(1, 2)
 39 |  40 | """.lstrip(
 40 |     | -    "\n"
 41 |     | -)
    |  41 | +            "\n"
    |  42 | +        ),
    |  43 | +    ),
    |  44 | +    (
    |  45 | +        """
    |  46 | +    query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    |  47 | +    key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
    |  48 | +    value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
    |  49 | +""".lstrip(
    |  50 | +            "\n"
    |  51 | +        ),
    |  52 | +        """
    |  53 | +    query_states, key_states, value_states = self.apply_qkv(hidden_states)
    |  54 | +    query_states = self.q_norm(query_states.view(hidden_shape)).transpose(1, 2)
    |  55 | +    key_states = self.k_norm(key_states.view(hidden_shape)).transpose(1, 2)
    |  56 | +    value_states = value_states.view(hidden_shape).transpose(1, 2)
    |  57 | +""".lstrip(
    |  58 | +            "\n"
    |  59 | +        ),
    |  60 | +    ),
    |  61 | +]
 42 |  62 |
 43 |  63 | ORIGINAL_O_CODE = """
 44 |  64 |     attn_output = self.o_proj(attn_output)
@@ -128,10 +148,11 @@ def get_attention_cls_from_config(cfg: DictDefault) -> Type[nn.Module]:
128 | 148 |     try:
129 | 149 |         # Dynamically import the module and attention class
130 | 150 |         module_path = f"transformers.models.{model_type}.modeling_{model_type}"
131 |     | -        module = __import__(
132 |     | -            module_path, fromlist=[f"{model_type.capitalize()}Attention"]
    | 151 | +        model_cls_prefix = "".join(
    | 152 | +            [part.capitalize() for part in model_type.split("_")]
133 | 153 |         )
134 |     | -        attention_cls = getattr(module, f"{model_type.capitalize()}Attention")
    | 154 | +        module = __import__(module_path, fromlist=[f"{model_cls_prefix}Attention"])
    | 155 | +        attention_cls = getattr(module, f"{model_cls_prefix}Attention")
135 | 156 |
136 | 157 |         return attention_cls
137 | 158 |     except (ImportError, AttributeError) as e:
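The hunk above changes how the attention class name is derived: instead of `model_type.capitalize()`, each underscore-separated part of `model_type` is capitalized before appending `Attention`. A minimal sketch of that behavior, with illustrative model types that are not taken from this diff:

```python
def attention_cls_name(model_type: str) -> str:
    # Capitalize each "_"-separated part of the model type, then append
    # "Attention" to form the class name imported from transformers.
    prefix = "".join(part.capitalize() for part in model_type.split("_"))
    return f"{prefix}Attention"


print(attention_cls_name("llama"))      # LlamaAttention
print(attention_cls_name("qwen2_moe"))  # Qwen2MoeAttention; plain .capitalize() would give Qwen2_moeAttention
```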
@@ -168,10 +189,18 @@ def patch_self_attn_lora(cfg: DictDefault):
168 | 189 |     attention_cls._original_forward = self_attn_forward
169 | 190 |     self_attn_forward, _ = detab_code(self_attn_forward)
170 | 191 |
171 |     | -    assert ORIGINAL_QKV_CODE in self_attn_forward, "Original QKV code not found"
    | 192 | +    assert any(
    | 193 | +        qkv_options[0] in self_attn_forward for qkv_options in QKV_PATCHES
    | 194 | +    ), "Original QKV code not found"
172 | 195 |     assert ORIGINAL_O_CODE in self_attn_forward, "Original O code not found"
173 | 196 |
174 |     | -    self_attn_forward = self_attn_forward.replace(ORIGINAL_QKV_CODE, PATCHED_QKV_CODE)
    | 197 | +    for qkv_orig, qkv_patched in QKV_PATCHES:
    | 198 | +        if qkv_orig in self_attn_forward:
    | 199 | +            self_attn_forward = self_attn_forward.replace(
    | 200 | +                qkv_orig,
    | 201 | +                qkv_patched,
    | 202 | +            )
    | 203 | +            break
175 | 204 |     self_attn_forward = self_attn_forward.replace(ORIGINAL_O_CODE, PATCHED_O_CODE)
176 | 205 |     self_attn_forward = self_attn_forward.replace(
177 | 206 |         "def forward(",
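Taken together, `patch_self_attn_lora` now asserts that at least one original snippet from `QKV_PATCHES` appears in the detabbed forward source and applies only the first matching original/patched pair. A rough standalone sketch of that selection logic, using hypothetical placeholder snippets rather than the real forward source:

```python
# Hypothetical stand-in pairs; the real entries are the QKV_PATCHES strings above.
PATCH_PAIRS = [
    ("q = self.q_proj(x)", "q, k, v = self.apply_qkv(x)"),
    ("q = self.q_norm(self.q_proj(x))", "q, k, v = self.apply_qkv(x)"),
]


def apply_first_match(source: str, pairs: list[tuple[str, str]]) -> str:
    # Mirror the patched loop: require that some original snippet is present,
    # then replace the first matching pair and stop.
    assert any(orig in source for orig, _ in pairs), "Original QKV code not found"
    for orig, patched in pairs:
        if orig in source:
            return source.replace(orig, patched)
    return source
```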