@@ -138,18 +138,25 @@ def parse_args():
138138 if args .model_name not in MODEL_NAME_TO_MODEL :
139139 raise RuntimeError (f"Available models are { list (MODEL_NAME_TO_MODEL .keys ())} ." )
140140
141- llm_config = LlmConfig ()
142141 if args .model_name == "llama2" :
142+ # Building LLM example.
143+ llm_config = LlmConfig ()
143144 if args .checkpoint :
144145 llm_config .base .checkpoint = args .checkpoint
145146 if args .params :
146147 llm_config .base .params = args .params
147148 llm_config .model .use_kv_cache = True
148- model , example_inputs , _ , _ = EagerModelFactory .create_model (
149- module_name = MODEL_NAME_TO_MODEL [args .model_name ][0 ],
150- model_class_name = MODEL_NAME_TO_MODEL [args .model_name ][1 ],
151- llm_config = llm_config ,
152- )
149+ model , example_inputs , _ , _ = EagerModelFactory .create_model (
150+ module_name = MODEL_NAME_TO_MODEL [args .model_name ][0 ],
151+ model_class_name = MODEL_NAME_TO_MODEL [args .model_name ][1 ],
152+ llm_config = llm_config ,
153+ )
154+ else :
155+ # Building non-LLM example.
156+ model , example_inputs , _ , _ = EagerModelFactory .create_model (
157+ module_name = MODEL_NAME_TO_MODEL [args .model_name ][0 ],
158+ model_class_name = MODEL_NAME_TO_MODEL [args .model_name ][1 ],
159+ )
153160
154161 model = model .eval ()
155162
0 commit comments