diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 5ccbc987b4d..4aa34fce9f8 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -138,18 +138,25 @@ def parse_args(): if args.model_name not in MODEL_NAME_TO_MODEL: raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.") - llm_config = LlmConfig() if args.model_name == "llama2": + # Building LLM example. + llm_config = LlmConfig() if args.checkpoint: llm_config.base.checkpoint = args.checkpoint if args.params: llm_config.base.params = args.params llm_config.model.use_kv_cache = True - model, example_inputs, _, _ = EagerModelFactory.create_model( - module_name=MODEL_NAME_TO_MODEL[args.model_name][0], - model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], - llm_config=llm_config, - ) + model, example_inputs, _, _ = EagerModelFactory.create_model( + module_name=MODEL_NAME_TO_MODEL[args.model_name][0], + model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], + llm_config=llm_config, + ) + else: + # Building non-LLM example. + model, example_inputs, _, _ = EagerModelFactory.create_model( + module_name=MODEL_NAME_TO_MODEL[args.model_name][0], + model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], + ) model = model.eval()