
Commit 180fe7f

Update on "Add LoRA linear definition"
Add LoRA linear definition. Pull out the linears from attention and allow a custom linear (e.g. a LoRA linear) to be passed in. If none is provided, construct the linear as before (current behaviour).

Differential Revision: [D73953776](https://our.internmc.facebook.com/intern/diff/D73953776/)

[ghstack-poisoned]
2 parents 8bcd073 + b4b4076 commit 180fe7f
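For context, here is a minimal, hypothetical sketch of the pattern the commit message describes: an attention-style module that accepts an optional pre-built linear and falls back to constructing a plain nn.Linear when none is supplied. The class name and parameter names below are illustrative assumptions, not ExecuTorch's actual API.

```python
# Sketch only: illustrates "allow a custom linear to be passed in; if none,
# construct the linear" with made-up names, not the real attention.py code.
from typing import Optional

import torch
import torch.nn as nn


class AttentionSketch(nn.Module):
    """Hypothetical attention-style module illustrating linear injection."""

    def __init__(
        self,
        dim: int,
        n_heads: int,
        head_dim: int,
        q_linear: Optional[nn.Module] = None,  # e.g. a LoRA linear built by the caller
    ):
        super().__init__()
        # Use the injected linear if one was passed in; otherwise construct a
        # plain nn.Linear, which matches the pre-existing behaviour.
        self.wq = (
            q_linear
            if q_linear is not None
            else nn.Linear(dim, n_heads * head_dim, bias=False)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.wq(x)


if __name__ == "__main__":
    # Default path: no custom linear supplied, so an nn.Linear is constructed.
    attn = AttentionSketch(dim=64, n_heads=4, head_dim=16)
    print(attn(torch.randn(2, 8, 64)).shape)  # torch.Size([2, 8, 64])
```

A caller doing LoRA fine-tuning would build its LoRA linear up front and pass it in, while existing callers keep the default nn.Linear path.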

1 file changed: +5 −9 lines

examples/models/llama/attention.py

Lines changed: 5 additions & 9 deletions
@@ -5,8 +5,8 @@
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.lora import LoRALinear
+from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope

@@ -373,8 +373,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "q_proj" in args.target_modules
+            if args.target_modules is not None and "q_proj" in args.target_modules
             else nn.Linear(
                 self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
             )

@@ -388,8 +387,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "k_proj" in args.target_modules
+            if args.target_modules is not None and "k_proj" in args.target_modules
             else nn.Linear(
                 self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
             )

@@ -403,8 +401,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "v_proj" in args.target_modules
+            if args.target_modules is not None and "v_proj" in args.target_modules
             else nn.Linear(
                 self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
             )

@@ -418,8 +415,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "output_proj" in args.target_modules
+            if args.target_modules is not None and "output_proj" in args.target_modules
             else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )

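The hunks above build a LoRALinear for a projection whenever that projection's name appears in args.target_modules, and otherwise fall back to nn.Linear. As a reference for what such a layer computes, here is a generic LoRA linear sketch; it is not ExecuTorch's LoRALinear, and the constructor arguments other than the dropout and use_bias seen in the diff (in_dim, out_dim, rank, alpha) are illustrative assumptions.

```python
# Generic LoRA linear sketch: frozen base weight plus a trainable low-rank update.
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        rank: int = 8,
        alpha: float = 16.0,
        dropout: float = 0.0,
        use_bias: bool = False,
    ):
        super().__init__()
        # Base projection is frozen; only the low-rank factors are trained.
        self.base = nn.Linear(in_dim, out_dim, bias=use_bias)
        self.base.weight.requires_grad_(False)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)   # down-projection
        self.lora_b = nn.Linear(rank, out_dim, bias=False)  # up-projection
        self.dropout = nn.Dropout(dropout)
        self.scaling = alpha / rank

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # y = W x + (alpha / rank) * B A dropout(x)
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))
```

Swapping a layer like this in for q_proj, k_proj, v_proj, or output_proj is exactly what the target_modules check above gates.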