99import logging
1010import os
1111import re
12- from typing import Any , Dict
12+ import sys
13+ from typing import Any , Dict , List
1314
15+ sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
1416from examples .models import MODEL_NAME_TO_MODEL
1517
1618
4547}
4648
4749
50+ def extract_all_configs (data , target_os = None ):
51+ if isinstance (data , dict ):
52+ # If target_os is specified, include "xplat" and the specified branch
53+ include_branches = {"xplat" , target_os } if target_os else data .keys ()
54+ return [
55+ v
56+ for key , value in data .items ()
57+ if key in include_branches
58+ for v in extract_all_configs (value , target_os )
59+ ]
60+ elif isinstance (data , list ):
61+ return [v for item in data for v in extract_all_configs (item , target_os )]
62+ else :
63+ return [data ]
64+
65+
66+ def generate_compatible_configs (model_name : str , target_os = None ) -> List [str ]:
67+ """
68+ Generate a list of compatible benchmark configurations for a given model name and target OS.
69+
70+ Args:
71+ model_name (str): The name of the model to generate configurations for.
72+ target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').
73+
74+ Returns:
75+ List[str]: A list of compatible benchmark configurations.
76+
77+ Raises:
78+ None
79+
80+ Example:
81+ generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
82+ """
83+ configs = []
84+ if is_valid_huggingface_model_id (model_name ):
85+ if model_name .startswith ("meta-llama/" ):
86+ # LLaMA models
87+ repo_name = model_name .split ("meta-llama/" )[1 ]
88+ if "qlora" in repo_name .lower ():
89+ configs .append ("llama3_qlora" )
90+ elif "spinquant" in repo_name .lower ():
91+ configs .append ("llama3_spinquant" )
92+ else :
93+ configs .append ("llama3_fb16" )
94+ configs .extend (
95+ [
96+ config
97+ for config in BENCHMARK_CONFIGS .get (target_os , [])
98+ if config .startswith ("llama" )
99+ ]
100+ )
101+ else :
102+ # Non-LLaMA models
103+ configs .append ("hf_xnnpack_fp32" )
104+ elif model_name in MODEL_NAME_TO_MODEL :
105+ # ExecuTorch in-tree non-GenAI models
106+ configs .append ("xnnpack_q8" )
107+ if target_os != "xplat" :
108+ # Add OS-specific configs
109+ configs .extend (
110+ [
111+ config
112+ for config in BENCHMARK_CONFIGS .get (target_os , [])
113+ if not config .startswith ("llama" )
114+ ]
115+ )
116+ else :
117+ # Skip unknown models with a warning
118+ logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
119+
120+ return configs
121+
122+
48123def parse_args () -> Any :
49124 """
50125 Parse command-line arguments.
@@ -82,6 +157,11 @@ def comma_separated(value: str):
82157 type = comma_separated , # Use the custom parser for comma-separated values
83158 help = f"Comma-separated device names. Available devices: { list (DEVICE_POOLS .keys ())} " ,
84159 )
160+ parser .add_argument (
161+ "--configs" ,
162+ type = comma_separated , # Use the custom parser for comma-separated values
163+ help = f"Comma-separated benchmark configs. Available configs: { extract_all_configs (BENCHMARK_CONFIGS )} " ,
164+ )
85165
86166 return parser .parse_args ()
87167
@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
98178 set_output("benchmark_configs", {"include": [...]})
99179 """
100180
101- if os .getenv ("GITHUB_OUTPUT" ):
102- print (f"Setting { val } to GitHub output" )
103- with open (str (os .getenv ("GITHUB_OUTPUT" )), "a" ) as env :
104- print (f"{ name } ={ val } " , file = env )
105- else :
181+ github_output = os .getenv ("GITHUB_OUTPUT" )
182+ if not github_output :
183+ print (f"::set-output name={ name } ::{ val } " )
184+ return
185+
186+ try :
187+ with open (github_output , "a" ) as env :
188+ env .write (f"{ name } ={ val } \n " )
189+ except PermissionError :
190+ # Fall back to printing in case of permission error in unit tests
106191 print (f"::set-output name={ name } ::{ val } " )
107192
108193
@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123208 return bool (re .match (pattern , model_name ))
124209
125210
126- def get_benchmark_configs () -> Dict [str , Dict ]:
211+ def get_benchmark_configs () -> Dict [str , Dict ]: # noqa: C901
127212 """
128213 Gather benchmark configurations for a given set of models on the target operating system and devices.
129214
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153238 }
154239 """
155240 args = parse_args ()
156- target_os = args .os
157241 devices = args .devices
158242 models = args .models
243+ target_os = args .os
244+ target_configs = args .configs
159245
160246 benchmark_configs = {"include" : []}
161247
162248 for model_name in models :
163249 configs = []
164- if is_valid_huggingface_model_id (model_name ):
165- if model_name .startswith ("meta-llama/" ):
166- # LLaMA models
167- repo_name = model_name .split ("meta-llama/" )[1 ]
168- if "qlora" in repo_name .lower ():
169- configs .append ("llama3_qlora" )
170- elif "spinquant" in repo_name .lower ():
171- configs .append ("llama3_spinquant" )
172- else :
173- configs .append ("llama3_fb16" )
174- configs .extend (
175- [
176- config
177- for config in BENCHMARK_CONFIGS .get (target_os , [])
178- if config .startswith ("llama" )
179- ]
250+ configs .extend (generate_compatible_configs (model_name , target_os ))
251+ print (f"Discovered all supported configs for model '{ model_name } ': { configs } " )
252+ if target_configs is not None :
253+ for config in target_configs :
254+ if config not in configs :
255+ raise Exception (
256+ f"Unsupported config '{ config } ' for model '{ model_name } ' on '{ target_os } '. Skipped.\n "
257+ f"Supported configs are: { configs } "
180258 )
181- else :
182- # Non-LLaMA models
183- configs .append ("hf_xnnpack_fp32" )
184- elif model_name in MODEL_NAME_TO_MODEL :
185- # ExecuTorch in-tree non-GenAI models
186- configs .append ("xnnpack_q8" )
187- configs .extend (
188- [
189- config
190- for config in BENCHMARK_CONFIGS .get (target_os , [])
191- if not config .startswith ("llama" )
192- ]
193- )
194- else :
195- # Skip unknown models with a warning
196- logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
197- continue
259+ configs = target_configs
260+ print (f"Using provided configs { configs } for model '{ model_name } '" )
198261
199262 # Add configurations for each valid device
200263 for device in devices :
0 commit comments