     Transformer,
     TransformerBlock,
 )
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope
 
+from torchtune.models import convert_weights
+
 try:
     from .fairseq2 import convert_to_llama_checkpoint
 
@@ -37,6 +40,86 @@ def convert_to_llama_checkpoint(**kwargs):
 from ..model_base import EagerModelBase
 
 
+def construct_llm(model_args: ModelArgs) -> Transformer:
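+    """Construct the Transformer described by model_args. Attention projections
+    named in model_args.target_modules ("q_proj", "k_proj", "v_proj") are built
+    as LoRALinear layers; the remaining projections are plain torch.nn.Linear."""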
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+
+    rope = Rope(model_args)
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+
+    for layer_id in range(model_args.n_layers):
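+        # For each of the q/k/v projections, use LoRALinear when the projection
+        # is listed in target_modules; otherwise fall back to a plain Linear of
+        # the same shape.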
+        wq = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_heads * model_args.head_dim,
+                rank=model_args.r,  # todo
+                alpha=model_args.lora_alpha,  # todo
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if "q_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+
+        wk = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_kv_heads * model_args.head_dim,
+                rank=model_args.r,  # todo
+                alpha=model_args.lora_alpha,  # todo
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if "k_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_kv_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+        wv = (
+            LoRALinear(
+                in_dim=model_args.dim,
+                out_dim=model_args.n_kv_heads * model_args.head_dim,
+                rank=model_args.r,  # todo
+                alpha=model_args.lora_alpha,  # todo
+                dropout=0.0,
+                use_bias=model_args.attention_qkv_bias,
+            )
+            if "v_proj" in model_args.target_modules
+            else (
+                torch.nn.Linear(
+                    model_args.dim,
+                    model_args.n_kv_heads * model_args.head_dim,
+                    bias=model_args.attention_qkv_bias,
+                )
+            )
+        )
+
+        # todo
+        wo = torch.nn.Linear(
+            model_args.n_heads * model_args.head_dim, model_args.dim, bias=False
+        )
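+        # Pass the (possibly LoRA-wrapped) projections into the attention
+        # module; previously the attention class constructed these internally.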
+        attention = cls(model_args, layer_id, rope, wq, wk, wv, wo)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    # Construct transformer model.
+    return Transformer(model_args, layers, rope)
+
+
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)
@@ -49,6 +132,10 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)
 
+        # Adapter
+        adapter_checkpoint = kwargs.get("adapter_checkpoint", None)
+        adapter_config = kwargs.get("adapter_config", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -132,6 +219,22 @@ def __init__(self, **kwargs):
         with open(params_path, "r") as f:
             params = json.loads(f.read())
 
+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None)
+        if adapter_checkpoint_path:
+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
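+            # torchtune adapter checkpoints use torchtune parameter names;
+            # convert them to Meta's llama naming so the keys line up with the
+            # base checkpoint before merging below.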
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+
+            adapter_config = kwargs.get("adapter_config", None)
+            with open(adapter_config, "r") as f:
+                adapter_config = json.loads(f.read())
+
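+            # Merge the adapter weights into the base checkpoint so they are
+            # loaded together with the original model weights.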
+            checkpoint.update(adapter_checkpoint)
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
@@ -156,6 +259,7 @@ def __init__(self, **kwargs):
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
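+            # adapter_config is expected to carry the LoRA fields consumed by
+            # construct_llm (r, lora_alpha, target_modules).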
+            **adapter_config,
         )
 
         if model_args.use_scaled_rope:
@@ -177,23 +281,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-
-            # Construct attention layers.
-            rope = Rope(model_args)
-            if model_args.attention_type not in ATTENTION_REGISTRY:
-                raise ValueError(
-                    f"Unknown attention type: {model_args.attention_type}. "
-                    f"Available: {list(ATTENTION_REGISTRY.keys())}"
-                )
-            layers = torch.nn.ModuleList()
-            cls = ATTENTION_REGISTRY[model_args.attention_type]
-            for layer_id in range(model_args.n_layers):
-                attention = cls(model_args, layer_id, rope)
-                transformer_block = TransformerBlock(model_args, attention)
-                layers.append(transformer_block)
-
-            # Construct transformer model.
-            self.model_ = Transformer(model_args, layers, rope)
+            self.model_ = construct_llm(model_args)
 
         # Get checkpoint dtype.
         if checkpoint: