@@ -5,6 +5,7 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -325,7 +326,20 @@ def update(
 
 @register_attention("mha")
 class AttentionMHA(Attention):
-    def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
+    def __init__(
+        self,
+        args: ModelArgs,
+        layer_id: int,
+        rope: Rope,
+    ):
+        """
+        Multi-head attention layer.
+
+        Args:
+            args (ModelArgs): Model configuration parameters.
+            layer_id (int): Layer index.
+            rope (Rope): Rotary position embedding module.
+        """
         super().__init__()
         self.use_kv_cache = args.use_kv_cache
         self.n_heads = args.n_heads
@@ -350,16 +364,60 @@ def __init__(self, args: ModelArgs, layer_id: int, rope: Rope):
         self.q_norm_fn = RMSNorm(q_norm_dim, eps=args.norm_eps)
         self.k_norm_fn = RMSNorm(k_norm_dim, eps=args.norm_eps)
 
-        self.wq = nn.Linear(
-            self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wq = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "q_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wk = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wk = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "k_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
         )
-        self.wv = nn.Linear(
-            self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+        self.wv = (
+            LoRALinear(
+                in_dim=args.dim,
+                out_dim=args.n_kv_heads * args.head_dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "v_proj" in args.target_modules
+            else nn.Linear(
+                self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
+            )
+        )
+        self.wo = (
+            LoRALinear(
+                in_dim=args.n_kv_heads * args.head_dim,
+                out_dim=args.dim,
+                rank=args.r,
+                alpha=args.lora_alpha,
+                dropout=0.0,
+                use_bias=args.attention_qkv_bias,
+            )
+            if args.target_modules is not None and "output_proj" in args.target_modules
+            else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
-        self.wo = nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
 
         self.layer_id = layer_id
 
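For context, the pattern in this diff swaps a projection for a LoRA-augmented linear layer only when that projection's name appears in args.target_modules, and otherwise keeps the plain nn.Linear. Below is a minimal standalone sketch of that conditional wiring; the LoRALinear here is a hypothetical stand-in written for illustration (standard low-rank adapter form), not the executorch.examples.models.llama.lora implementation, and the parameter names (target_modules, rank, alpha) simply mirror the ones used above.

# Illustrative sketch only: per-module LoRA injection, assuming an adapter of
# the usual form y = Wx + (alpha / rank) * B(A(dropout(x))).
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Hypothetical stand-in for a LoRA-augmented linear layer."""

    def __init__(self, in_dim, out_dim, rank, alpha, dropout=0.0, use_bias=False):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=use_bias)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)   # down-projection A
        self.lora_b = nn.Linear(rank, out_dim, bias=False)  # up-projection B
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, x):
        # Base projection plus the scaled low-rank update.
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))


def make_proj(name, in_dim, out_dim, target_modules, rank, alpha):
    # Same conditional as the diff: use LoRA only for projections named in
    # target_modules, otherwise fall back to a plain nn.Linear.
    if target_modules is not None and name in target_modules:
        return LoRALinear(in_dim, out_dim, rank=rank, alpha=alpha)
    return nn.Linear(in_dim, out_dim, bias=False)


targets = ["q_proj", "v_proj"]
wq = make_proj("q_proj", 512, 512, targets, rank=8, alpha=16.0)
wk = make_proj("k_proj", 512, 512, targets, rank=8, alpha=16.0)
print(type(wq).__name__, type(wk).__name__)  # LoRALinear Linear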