 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.norm import RMSNorm
 from executorch.examples.models.llama.rope import Rope
@@ -372,8 +373,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "q_proj" in args.target_modules
+            if args.target_modules is not None and "q_proj" in args.target_modules
             else nn.Linear(
                 self.dim, self.n_heads * self.head_dim, bias=self.attention_qkv_bias
             )
@@ -387,8 +387,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "k_proj" in args.target_modules
+            if args.target_modules is not None and "k_proj" in args.target_modules
             else nn.Linear(
                 self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
             )
@@ -402,8 +401,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "v_proj" in args.target_modules
+            if args.target_modules is not None and "v_proj" in args.target_modules
             else nn.Linear(
                 self.dim, self.n_kv_heads * self.head_dim, bias=self.attention_qkv_bias
             )
@@ -417,8 +415,7 @@ def __init__(
                 dropout=0.0,
                 use_bias=args.attention_qkv_bias,
             )
-            if args.target_modules is not None
-            and "output_proj" in args.target_modules
+            if args.target_modules is not None and "output_proj" in args.target_modules
             else nn.Linear(self.n_heads * self.head_dim, self.dim, bias=False)
         )
 
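For context, the change above keeps the existing pattern of conditionally swapping each attention projection between a LoRALinear and a plain nn.Linear based on args.target_modules, and only reflows the condition onto one line. The sketch below illustrates that selection pattern; it is a minimal, hypothetical stand-in (LoRALinearSketch, make_proj, and the example dimensions are illustrative, not the ExecuTorch LoRALinear imported in the diff).

```python
import torch
import torch.nn as nn


class LoRALinearSketch(nn.Module):
    """Illustrative LoRA-wrapped linear layer: base projection plus a low-rank update."""

    def __init__(self, in_dim, out_dim, rank=8, alpha=16.0, dropout=0.0, use_bias=False):
        super().__init__()
        self.base = nn.Linear(in_dim, out_dim, bias=use_bias)
        self.lora_a = nn.Linear(in_dim, rank, bias=False)
        self.lora_b = nn.Linear(rank, out_dim, bias=False)
        self.scaling = alpha / rank
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # Base projection plus the low-rank update B(A(x)), scaled by alpha / rank.
        return self.base(x) + self.scaling * self.lora_b(self.lora_a(self.dropout(x)))


def make_proj(name, in_dim, out_dim, target_modules, bias=False):
    # Mirror the diff's pattern: use LoRA only when this projection's name
    # appears in target_modules, otherwise fall back to a plain nn.Linear.
    if target_modules is not None and name in target_modules:
        return LoRALinearSketch(in_dim, out_dim, use_bias=bias)
    return nn.Linear(in_dim, out_dim, bias=bias)


# Example: only q_proj and v_proj receive LoRA adapters (dims are made up).
target = ["q_proj", "v_proj"]
wq = make_proj("q_proj", 4096, 4096, target)  # -> LoRALinearSketch
wk = make_proj("k_proj", 4096, 1024, target)  # -> nn.Linear
x = torch.randn(1, 16, 4096)
print(wq(x).shape, wk(x).shape)
```

The ternary in the diff inlines this same check separately for q_proj, k_proj, v_proj, and output_proj, so only the modules named in target_modules carry LoRA weights.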