2 files changed, +7 -8 lines changed
```diff
@@ -1226,7 +1226,7 @@ def _load_llama_model(
         EagerModelFactory.create_model(
             module_name,
             model_class_name,
-            llm_config=llm_config,
+            model_args={"llm_config": llm_config},
         )
     )
 
```
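The call site now wraps the config in a `model_args` dict instead of passing `llm_config` as a loose keyword argument. As a rough illustration of why that shape lines up with the constructor change below, here is a minimal, hypothetical sketch of a factory that resolves the class by name and unpacks `model_args` into the constructor; the real `EagerModelFactory` in ExecuTorch is not shown in this diff and may work differently.

```python
# Hypothetical sketch only: the actual EagerModelFactory implementation
# is not part of this diff and may resolve model classes differently.
import importlib


def create_model(module_name, model_class_name, model_args=None):
    """Import module_name, look up model_class_name, and instantiate it.

    Unpacking model_args means {"llm_config": llm_config} becomes
    Llama2Model(llm_config=llm_config), matching the new constructor.
    """
    module = importlib.import_module(module_name)
    model_class = getattr(module, model_class_name)
    return model_class(**(model_args or {}))
```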
```diff
@@ -36,19 +36,18 @@ def convert_to_llama_checkpoint(**kwargs):
 
 
 class Llama2Model(EagerModelBase):
-    def __init__(self, **kwargs):
+    def __init__(self, llm_config):
         resource_dir = get_default_model_resource_dir(__file__)
 
+        self.llm_config = llm_config
+
         # Use single checkpoint file.
-        checkpoint_path = kwargs.get("checkpoint", None)
+        checkpoint_path = self.llm_config.base.checkpoint
         # Check if checkpoint_dir was provided for a sharded checkpoint.
-        checkpoint_dir = kwargs.get("checkpoint_dir", None)
+        checkpoint_dir = self.llm_config.base.checkpoint_dir
 
         # Params file.
-        params_path = kwargs.get("params", None)
-
-        self.llm_config = kwargs.get("llm_config")
-        assert self.llm_config is not None, "llm_config must be provided"
+        params_path = self.llm_config.base.params
 
         self.use_kv_cache = self.llm_config.model.use_kv_cache
         self.use_sdpa_with_kv_cache_op = self.llm_config.model.use_sdpa_with_kv_cache
```
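The constructor now takes a structured `llm_config` instead of scraping optional kwargs, which also makes the runtime `assert` unnecessary: a missing config fails at call time rather than deep inside `__init__`. To make the attribute paths concrete, here is a minimal, hypothetical stand-in covering only the fields this hunk reads; the real `LlmConfig` in ExecuTorch has many more fields and these dataclass names are assumptions.

```python
# Minimal, hypothetical stand-in for the llm_config object read above;
# the real LlmConfig class defines many more options than shown here.
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class BaseConfig:
    checkpoint: Optional[str] = None      # path to a single checkpoint file
    checkpoint_dir: Optional[str] = None  # directory of sharded checkpoints
    params: Optional[str] = None          # path to the params file


@dataclass
class ModelConfig:
    use_kv_cache: bool = False
    use_sdpa_with_kv_cache: bool = False


@dataclass
class LlmConfig:
    base: BaseConfig = field(default_factory=BaseConfig)
    model: ModelConfig = field(default_factory=ModelConfig)


# With such a config, the new constructor would be called as:
# model = Llama2Model(LlmConfig(base=BaseConfig(checkpoint="llama2.pt")))
```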