
import json
import os
-from pathlib import Path
+from typing import Dict, Tuple

import torch
+from executorch.examples.models.checkpoint import (
+    get_checkpoint_dtype,
+    get_default_model_resource_dir,
+)

from executorch.examples.models.llama2.llama_transformer import ModelArgs, Transformer

@@ -30,48 +34,31 @@ def convert_to_llama_checkpoint(**kwargs):

class Llama2Model(EagerModelBase):
    def __init__(self, **kwargs):
-        import pkg_resources
-
-        # default path to the resource file
-        # It currently supports 3 ways of specifying the checkpoint location:
-        # 1. Using default path locates in examples/models/llama2/params
-        # 2. Passing in the checkpoint path and params via kwargs
-        # 3. Using the path from pkg_resources, only works with buck2
-        try:
-            # The 3rd way, if we can import this path, we are running with buck2, all resources can be accessed with pkg_resources.resource_filename
-            # pyre-ignore
-            from executorch.examples.models.llama2 import params
-
-            ckpt_dir = Path(
-                pkg_resources.resource_filename(
-                    "executorch.examples.models.llama2", "params"
-                )
-            )
-        except:
-            # The 1st way
-            ckpt_dir = Path(__file__).absolute().parent / "params"
-
-        # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        resource_dir = get_default_model_resource_dir(__file__)

        # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", ckpt_dir / "demo_rand_params.pth")
+        checkpoint_path = kwargs.get(
+            "checkpoint", resource_dir / "demo_rand_params.pth"
+        )
+        params_path = kwargs.get("params", resource_dir / "demo_config.json")

-        params_path = kwargs.get("params", ckpt_dir / "demo_config.json")
+        # Check if checkpoint_dir was provided for a sharded checkpoint.
+        checkpoint_dir = kwargs.get("checkpoint_dir", None)

        self.use_kv_cache = kwargs.get("use_kv_cache", False)
        self.use_sdpa_with_kv_cache_op = kwargs.get("use_sdpa_with_kv_cache", False)
        self.generate_full_logits = kwargs.get("generate_full_logits", False)
        self.enable_dynamic_shape = kwargs.get("enable_dynamic_shape", False)
        self.output_prune_map_path = kwargs.get("output_prune_map_path", None)
-
        self.max_seq_len = kwargs.get("max_seq_len", 128)
        self.args = kwargs.get("args", None)
+
        # The example is using a dummy small model with random weights for demo purpose only.
-        # Follow the instruction in https://github.com/facebookresearch/llama to download the model
+        # Follow the instruction in https://github.com/facebookresearch/llama to download the model.
        device = "cpu"
        # flake8: noqa: TOR102
        cps = []
+        # Load sharded checkpoint.
        if checkpoint_dir is not None:
            # Load multiple checkpoint; ignore the single path.
            checkpoint_path = None
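The pkg_resources/Path fallback deleted above appears to have moved into the shared get_default_model_resource_dir helper. A rough sketch of what that helper presumably does, reconstructed only from the removed inline logic (the actual implementation in executorch.examples.models.checkpoint may differ, and likely derives the package name from the caller instead of hardcoding it):

from pathlib import Path


def get_default_model_resource_dir(model_file_path: str) -> Path:
    try:
        # Under buck2 the params package is importable and resources resolve
        # through pkg_resources.resource_filename.
        import pkg_resources
        from executorch.examples.models.llama2 import params  # noqa: F401

        return Path(
            pkg_resources.resource_filename(
                "executorch.examples.models.llama2", "params"
            )
        )
    except Exception:
        # Otherwise fall back to the `params` directory next to the model file.
        return Path(model_file_path).absolute().parent / "params"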
@@ -98,8 +85,11 @@ def __init__(self, **kwargs):
                else:
                    # Do not duplicate layers shared between each checkpoint.
                    checkpoint[key] = cps[0][key]
+        # Load single checkpoint.
        else:
            checkpoint = torch.load(checkpoint_path, map_location=device, mmap=True)
+
+        # If given checkpoint is fairseq, convert to llama checkpoint.
        fairseq2_checkpoint = kwargs.get("fairseq2", False)
        if fairseq2_checkpoint:
            print("Using fairseq2 checkpoint")
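The fairseq2 handling hinges on one heuristic, which the next hunk turns from a printed warning into a hard ValueError: a fairseq2 export carries a `final_proj.weight` entry that a native llama checkpoint does not. Expressed as a tiny standalone check (illustrative only; `looks_like_fairseq2` is a hypothetical name, not part of executorch):

from typing import Any, Dict


def looks_like_fairseq2(checkpoint: Dict[str, Any]) -> bool:
    # Mirrors the guard in __init__: presence of `final_proj.weight` marks the
    # state dict as a fairseq2 export that still needs
    # convert_to_llama_checkpoint() before it can be loaded.
    return checkpoint.get("final_proj.weight", None) is not None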
@@ -108,12 +98,12 @@ def __init__(self, **kwargs):
            # NB: some checkpoint contains a "model" field, which is the actual weights dict
            checkpoint = checkpoint["model"]

+        # Check if user gave a fairseq2 checkpoint unknowingly without specifying --fairseq2.
        if (not fairseq2_checkpoint) and checkpoint.get(
            "final_proj.weight", None
        ) is not None:
-            print(
+            raise ValueError(
                """
-
************************************************************
This looks like a Fairseq2 checkpoint (based on the presence
of `final_proj.weight`.
@@ -125,44 +115,28 @@ def __init__(self, **kwargs):
"""
            )

-        # get checkpoint dtype
-        self.dtype = None
-        if len(checkpoint) > 0:
-            first_key = next(iter(checkpoint))
-            first = checkpoint[first_key]
-            self.dtype = first.dtype
-            mismatched_dtypes = [
-                (key, value.dtype)
-                for key, value in checkpoint.items()
-                if value.dtype != self.dtype
-            ]
-            if len(mismatched_dtypes) > 0:
-                print(
-                    f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. Mismatches in the checkpoint: {mismatched_dtypes}"
-                )
+        # Get checkpoint dtype.
+        self.dtype = get_checkpoint_dtype(checkpoint)
+
        with open(params_path, "r") as f:
            params = json.loads(f.read())
        output_prune_map = None
        if self.output_prune_map_path is not None:
            with open(self.output_prune_map_path, "r") as f:
                output_prune_map = json.load(f)
-            # change keys from string to int (json only supports string keys)
+            # Change keys from string to int (json only supports string keys).
            output_prune_map = {int(k): v for (k, v) in output_prune_map.items()}
-        max_seq_len = self.max_seq_len
-        max_batch_size = 1
+
        model_args: ModelArgs = ModelArgs(
-            max_seq_len=max_seq_len,
-            max_batch_size=max_batch_size,
+            max_seq_len=self.max_seq_len,
+            max_batch_size=1,
            use_kv_cache=self.use_kv_cache,
            use_sdpa_with_kv_cache_op=self.use_sdpa_with_kv_cache_op,
            generate_full_logits=self.generate_full_logits,
            output_prune_map=output_prune_map,
            enable_dynamic_shape=self.enable_dynamic_shape,
            **params,
        )
-        if kwargs.get("fairseq2", False):
-            print("Using fairseq2 checkpoint")
-            checkpoint = convert_to_llama_checkpoint(checkpoint=checkpoint)
        if kwargs.get("verbose", False):
            print("============= weights ================")
            print("{key} : {weights.numel()} : {weights.size()}")
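Similarly, the dtype probing deleted in this hunk is now a one-liner call to get_checkpoint_dtype. Judging from the removed code, the shared helper presumably looks roughly like this (a sketch, not the actual executorch.examples.models.checkpoint implementation):

from typing import Any, Dict, Optional

import torch


def get_checkpoint_dtype(checkpoint: Dict[str, Any]) -> Optional[torch.dtype]:
    dtype = None
    if len(checkpoint) > 0:
        first_key = next(iter(checkpoint))
        first = checkpoint[first_key]
        dtype = first.dtype
        # Warn if the checkpoint mixes dtypes; the first tensor's dtype wins.
        mismatched_dtypes = [
            (key, value.dtype)
            for key, value in checkpoint.items()
            if value.dtype != dtype
        ]
        if len(mismatched_dtypes) > 0:
            print(
                f"Mixed dtype model. Dtype of {first_key}: {first.dtype}. "
                f"Mismatches in the checkpoint: {mismatched_dtypes}"
            )
    return dtype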
@@ -234,13 +208,13 @@ def __init__(self, **kwargs):
            print(unexpected)
            print("============= /unexpected ================")

-        # prune the output layer if output_prune_map is provided
+        # Prune the output layer if output_prune_map is provided
        if output_prune_map is not None:
            from .source_transformation.prune_output import prune_output_vocab

            self.model_ = prune_output_vocab(self.model_, output_prune_map)

-    def get_eager_model(self):
+    def get_eager_model(self) -> torch.nn.Module:
        if self.dtype:
            # convert to the type of the provided checkpoint
            # input and output are torch.long, so signature unchanged
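From a caller's point of view the refactor leaves the constructor interface unchanged; a hypothetical invocation pieced together from the kwargs this __init__ reads (paths and flag values are placeholders, not executorch defaults):

# Hypothetical usage; argument names come from the kwargs.get() calls above.
model = Llama2Model(
    checkpoint="/path/to/consolidated.00.pth",  # or checkpoint_dir=... for sharded weights
    params="/path/to/params.json",
    use_kv_cache=True,
    use_sdpa_with_kv_cache=True,
    max_seq_len=128,
)
eager_module = model.get_eager_model()  # now annotated to return torch.nn.Module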