     Transformer,
     TransformerBlock,
 )
+from executorch.examples.models.llama.lora import LoRALinear
 from executorch.examples.models.llama.model_args import ModelArgs
 from executorch.examples.models.llama.rope import Rope

+from torchtune.models import convert_weights
+
 try:
     from .fairseq2 import convert_to_llama_checkpoint

@@ -37,6 +40,87 @@ def convert_to_llama_checkpoint(**kwargs):
 from ..model_base import EagerModelBase


+def construct_llm(model_args: ModelArgs) -> Transformer:
+    if model_args.attention_type not in ATTENTION_REGISTRY:
+        raise ValueError(
+            f"Unknown attention type: {model_args.attention_type}. "
+            f"Available: {list(ATTENTION_REGISTRY.keys())}"
+        )
+
+    rope = Rope(model_args)
+    layers = torch.nn.ModuleList()
+    cls = ATTENTION_REGISTRY[model_args.attention_type]
+
+    wq = (
+        LoRALinear(
+            in_dim=model_args.dim,
+            out_dim=model_args.n_heads * model_args.head_dim,
+            rank=model_args.r,  # todo
+            alpha=model_args.lora_alpha,  # todo
+            dropout=0.0,
+            use_bias=model_args.attention_qkv_bias,
+        )
+        if "q_proj" in model_args.target_modules
+        else (
+            torch.nn.Linear(
+                model_args.dim,
+                model_args.n_heads * model_args.head_dim,
+                bias=model_args.attention_qkv_bias,
+            )
+        )
+    )
+
+    wk = (
+        LoRALinear(
+            in_dim=model_args.dim,
+            out_dim=model_args.n_kv_heads * model_args.head_dim,
+            rank=model_args.r,  # todo
+            alpha=model_args.lora_alpha,  # todo
+            dropout=0.0,
+            use_bias=model_args.attention_qkv_bias,
+        )
+        if "k_proj" in model_args.target_modules
+        else (
+            torch.nn.Linear(
+                model_args.dim,
+                model_args.n_kv_heads * model_args.head_dim,
+                bias=model_args.attention_qkv_bias,
+            )
+        )
+    )
+    wv = (
+        LoRALinear(
+            in_dim=model_args.dim,
+            out_dim=model_args.n_kv_heads * model_args.head_dim,
+            rank=model_args.r,  # todo
+            alpha=model_args.lora_alpha,  # todo
+            dropout=0.0,
+            use_bias=model_args.attention_qkv_bias,
+        )
+        if "v_proj" in model_args.target_modules
+        else (
+            torch.nn.Linear(
+                model_args.dim,
+                model_args.n_kv_heads * model_args.head_dim,
+                bias=model_args.attention_qkv_bias,
+            )
+        )
+    )
+
+    # todo
+    wo = torch.nn.Linear(
+        model_args.n_heads * model_args.head_dim, model_args.dim, bias=False
+    )
+
+    for layer_id in range(model_args.n_layers):
+        attention = cls(model_args, layer_id, rope, wq, wk, wv, wo)
+        transformer_block = TransformerBlock(model_args, attention)
+        layers.append(transformer_block)
+
+    # Construct transformer model.
+    return Transformer(model_args, layers, rope)
+
+
 class Llama2Model(EagerModelBase):
     def __init__(self, **kwargs):
         resource_dir = get_default_model_resource_dir(__file__)
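For orientation, a minimal sketch of how the new construct_llm helper is driven, mirroring the change at the bottom of this diff. The ModelArgs values are illustrative placeholders; in practice they come from params.json plus the adapter config, and the LoRA fields (r, lora_alpha, target_modules) are assumed to be populated only when an adapter is supplied.

import torch

from executorch.examples.models.llama.model_args import ModelArgs

# Illustrative arguments only; real values are read from params.json and
# adapter_config.json. attention_type must be a key of ATTENTION_REGISTRY.
model_args = ModelArgs(
    dim=2048,
    n_layers=16,
    n_heads=32,
    n_kv_heads=8,
    vocab_size=128256,
    r=8,                                  # assumed LoRA rank field
    lora_alpha=16,                        # assumed LoRA scaling field
    target_modules=["q_proj", "v_proj"],  # projections that get LoRALinear
)

# Parameters are created on the meta device; real weights are loaded from
# the (base + adapter) checkpoint afterwards, as __init__ does below.
with torch.device("meta"):
    model = construct_llm(model_args)  # defined in this module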
@@ -49,6 +133,10 @@ def __init__(self, **kwargs):
         # Params file.
         params_path = kwargs.get("params", None)

+        # Adapter
+        adapter_checkpoint = kwargs.get("adapter_checkpoint", None)
+        adapter_config = kwargs.get("adapter_config", None)
+
         self.use_kv_cache = kwargs.get("use_kv_cache", False)
         self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
         self.generate_full_logits = kwargs.get("generate_full_logits", False)
@@ -132,6 +220,22 @@ def __init__(self, **kwargs):
         with open(params_path, "r") as f:
             params = json.loads(f.read())

+        # Get adapter checkpoint and config.
+        adapter_checkpoint = {}
+        adapter_config = {}
+        adapter_checkpoint_path = kwargs.get("adapter_checkpoint", None)
+        if adapter_checkpoint_path:
+            adapter_checkpoint = torch.load(
+                adapter_checkpoint_path, map_location=device, mmap=True
+            )
+            adapter_checkpoint = convert_weights.tune_to_meta(adapter_checkpoint)
+
+            adapter_config = kwargs.get("adapter_config", None)
+            with open(adapter_config, "r") as f:
+                adapter_config = json.loads(f.read())
+
+            checkpoint.update(adapter_checkpoint)
+
         output_prune_map = None
         if self.output_prune_map_path is not None:
             with open(self.output_prune_map_path, "r") as f:
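The adapter config loaded above is splatted into ModelArgs in the next hunk (**adapter_config), and construct_llm reads r, lora_alpha, and target_modules from it. Below is a hypothetical example of what json.loads would return for such a file; the exact keys and any extra fields depend on how the adapter was exported from torchtune.

# Hypothetical contents of the file passed as kwargs["adapter_config"].
adapter_config = {
    "r": 8,                                  # LoRA rank
    "lora_alpha": 16,                        # LoRA scaling factor
    "target_modules": ["q_proj", "v_proj"],  # projections swapped for LoRALinear
}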
@@ -156,6 +260,7 @@ def __init__(self, **kwargs):
             output_prune_map=output_prune_map,
             enable_dynamic_shape=self.enable_dynamic_shape,
             **params,
+            **adapter_config,
         )

         if model_args.use_scaled_rope:
@@ -177,23 +282,7 @@ def __init__(self, **kwargs):
         # They possess all other metadata a tensor carries such as size, stride, requires_grad.
         with torch.device("meta"):
             # Model itself is loaded in default dtype, fp32.
-
-            # Construct attention layers.
-            rope = Rope(model_args)
-            if model_args.attention_type not in ATTENTION_REGISTRY:
-                raise ValueError(
-                    f"Unknown attention type: {model_args.attention_type}. "
-                    f"Available: {list(ATTENTION_REGISTRY.keys())}"
-                )
-            layers = torch.nn.ModuleList()
-            cls = ATTENTION_REGISTRY[model_args.attention_type]
-            for layer_id in range(model_args.n_layers):
-                attention = cls(model_args, layer_id, rope)
-                transformer_block = TransformerBlock(model_args, attention)
-                layers.append(transformer_block)
-
-            # Construct transformer model.
-            self.model_ = Transformer(model_args, layers, rope)
+            self.model_ = construct_llm(model_args)

         # Get checkpoint dtype.
         if checkpoint:
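End to end, a hedged sketch of instantiating the eager model with the new adapter kwargs read in __init__ above. The paths are placeholders, and the name of the base-checkpoint kwarg is assumed here rather than taken from this diff.

model = Llama2Model(
    checkpoint="/path/to/consolidated.00.pth",         # base weights (kwarg name assumed)
    params="/path/to/params.json",
    adapter_checkpoint="/path/to/adapter_model.pt",     # torchtune LoRA weights
    adapter_config="/path/to/adapter_config.json",      # rank/alpha/target_modules
    use_kv_cache=True,
)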