From 1698880eeb01a83f080124a2bd2e9d3069e9cb12 Mon Sep 17 00:00:00 2001 From: Jack Zhang <32371937+jackzhxng@users.noreply.github.com> Date: Tue, 10 Jun 2025 14:40:30 -0700 Subject: [PATCH] Fix mps example for non-LLMs --- examples/apple/mps/scripts/mps_example.py | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/examples/apple/mps/scripts/mps_example.py b/examples/apple/mps/scripts/mps_example.py index 5ccbc987b4d..4aa34fce9f8 100644 --- a/examples/apple/mps/scripts/mps_example.py +++ b/examples/apple/mps/scripts/mps_example.py @@ -138,18 +138,25 @@ def parse_args(): if args.model_name not in MODEL_NAME_TO_MODEL: raise RuntimeError(f"Available models are {list(MODEL_NAME_TO_MODEL.keys())}.") - llm_config = LlmConfig() if args.model_name == "llama2": + # Building LLM example. + llm_config = LlmConfig() if args.checkpoint: llm_config.base.checkpoint = args.checkpoint if args.params: llm_config.base.params = args.params llm_config.model.use_kv_cache = True - model, example_inputs, _, _ = EagerModelFactory.create_model( - module_name=MODEL_NAME_TO_MODEL[args.model_name][0], - model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], - llm_config=llm_config, - ) + model, example_inputs, _, _ = EagerModelFactory.create_model( + module_name=MODEL_NAME_TO_MODEL[args.model_name][0], + model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], + llm_config=llm_config, + ) + else: + # Building non-LLM example. + model, example_inputs, _, _ = EagerModelFactory.create_model( + module_name=MODEL_NAME_TO_MODEL[args.model_name][0], + model_class_name=MODEL_NAME_TO_MODEL[args.model_name][1], + ) model = model.eval()