99import logging
1010import os
1111import re
12- from typing import Any , Dict
12+ import sys
13+ from typing import Any , Dict , List , NamedTuple
1314
15+ sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
1416from examples .models import MODEL_NAME_TO_MODEL
1517
1618
4547}
4648
4749
# A benchmark config that is temporarily switched off for a model, together
# with the GitHub issue that tracks why it was disabled.
DisabledConfig = NamedTuple(
    "DisabledConfig",
    [
        ("config_name", str),
        ("github_issue", str),  # Link to the GitHub issue
    ],
)
53+
54+
# Known-broken (model, config) pairs, each pinned to the GitHub issue that
# tracks the breakage. generate_compatible_configs() filters these out.
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
    model: [DisabledConfig(config_name=config, github_issue=issue)]
    for model, config, issue in (
        ("resnet50", "qnn_q8", "https://github.com/pytorch/executorch/issues/7892"),
        ("w2l", "qnn_q8", "https://github.com/pytorch/executorch/issues/7634"),
        ("mobilebert", "mps", "https://github.com/pytorch/executorch/issues/7904"),
        ("edsr", "mps", "https://github.com/pytorch/executorch/issues/7905"),
        ("llama", "mps", "https://github.com/pytorch/executorch/issues/7907"),
    )
}
88+
89+
def extract_all_configs(data, target_os=None):
    """Recursively flatten a nested benchmark-config structure into a flat list.

    Args:
        data: Arbitrarily nested dicts/lists whose leaves are config names.
        target_os: When given, only the "xplat" branch and the branch matching
            this OS key are traversed at each dict level; otherwise every
            branch is traversed.

    Returns:
        A list of the leaf values, in traversal (insertion) order.
    """
    if isinstance(data, dict):
        # If target_os is specified, include "xplat" and the specified branch.
        include_branches = {"xplat", target_os} if target_os else data.keys()
        return [
            leaf
            for key, value in data.items()
            if key in include_branches
            for leaf in extract_all_configs(value, target_os)
        ]
    if isinstance(data, list):
        return [leaf for item in data for leaf in extract_all_configs(item, target_os)]
    # Leaf node: a single config name.
    return [data]
104+
105+
def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
    """Build the list of benchmark configs that apply to the given model.

    Args:
        model_name (str): Model to look up (HuggingFace repo id or in-tree name).
        target_os (Optional[str]): Target platform key (e.g. 'android', 'ios').

    Returns:
        List[str]: Compatible benchmark config names, with disabled ones removed.

    Raises:
        None

    Example:
        generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
    """
    os_configs = BENCHMARK_CONFIGS.get(target_os, [])
    configs: List[str] = []

    if is_valid_huggingface_model_id(model_name):
        if not model_name.startswith("meta-llama/"):
            # Non-LLaMA HuggingFace models get the generic XNNPACK config.
            configs.append("hf_xnnpack_fp32")
        else:
            # LLaMA models: choose a quantization-specific config based on the repo name.
            repo_name = model_name.split("meta-llama/")[1].lower()
            if "qlora" in repo_name:
                configs.append("llama3_qlora")
            elif "spinquant" in repo_name:
                configs.append("llama3_spinquant")
            else:
                configs.append("llama3_fb16")
                configs += [c for c in os_configs if c.startswith("llama")]
    elif model_name in MODEL_NAME_TO_MODEL:
        # ExecuTorch in-tree non-GenAI models.
        configs.append("xnnpack_q8")
        if target_os != "xplat":
            # Add OS-specific configs (everything that is not a llama config).
            configs += [c for c in os_configs if not c.startswith("llama")]
    else:
        # Unknown models are skipped with a warning rather than failing hard.
        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

    # Drop any configs explicitly disabled for this model.
    disabled_entries = DISABLED_CONFIGS.get(model_name, [])
    excluded = {entry.config_name for entry in disabled_entries}
    for entry in disabled_entries:
        print(
            f"Excluding disabled config: '{entry.config_name}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {entry.github_issue}"
        )
    return [c for c in configs if c not in excluded]
169+
170+
48171def parse_args () -> Any :
49172 """
50173 Parse command-line arguments.
@@ -82,6 +205,11 @@ def comma_separated(value: str):
82205 type = comma_separated , # Use the custom parser for comma-separated values
83206 help = f"Comma-separated device names. Available devices: { list (DEVICE_POOLS .keys ())} " ,
84207 )
208+ parser .add_argument (
209+ "--configs" ,
210+ type = comma_separated , # Use the custom parser for comma-separated values
211+ help = f"Comma-separated benchmark configs. Available configs: { extract_all_configs (BENCHMARK_CONFIGS )} " ,
212+ )
85213
86214 return parser .parse_args ()
87215
@@ -98,11 +226,16 @@ def set_output(name: str, val: Any) -> None:
98226 set_output("benchmark_configs", {"include": [...]})
99227 """
100228
101- if os .getenv ("GITHUB_OUTPUT" ):
102- print (f"Setting { val } to GitHub output" )
103- with open (str (os .getenv ("GITHUB_OUTPUT" )), "a" ) as env :
104- print (f"{ name } ={ val } " , file = env )
105- else :
229+ github_output = os .getenv ("GITHUB_OUTPUT" )
230+ if not github_output :
231+ print (f"::set-output name={ name } ::{ val } " )
232+ return
233+
234+ try :
235+ with open (github_output , "a" ) as env :
236+ env .write (f"{ name } ={ val } \n " )
237+ except PermissionError :
238+ # Fall back to printing in case of permission error in unit tests
106239 print (f"::set-output name={ name } ::{ val } " )
107240
108241
@@ -123,7 +256,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123256 return bool (re .match (pattern , model_name ))
124257
125258
126- def get_benchmark_configs () -> Dict [str , Dict ]:
259+ def get_benchmark_configs () -> Dict [str , Dict ]: # noqa: C901
127260 """
128261 Gather benchmark configurations for a given set of models on the target operating system and devices.
129262
@@ -153,48 +286,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153286 }
154287 """
155288 args = parse_args ()
156- target_os = args .os
157289 devices = args .devices
158290 models = args .models
291+ target_os = args .os
292+ target_configs = args .configs
159293
160294 benchmark_configs = {"include" : []}
161295
162296 for model_name in models :
163297 configs = []
164- if is_valid_huggingface_model_id (model_name ):
165- if model_name .startswith ("meta-llama/" ):
166- # LLaMA models
167- repo_name = model_name .split ("meta-llama/" )[1 ]
168- if "qlora" in repo_name .lower ():
169- configs .append ("llama3_qlora" )
170- elif "spinquant" in repo_name .lower ():
171- configs .append ("llama3_spinquant" )
172- else :
173- configs .append ("llama3_fb16" )
174- configs .extend (
175- [
176- config
177- for config in BENCHMARK_CONFIGS .get (target_os , [])
178- if config .startswith ("llama" )
179- ]
298+ configs .extend (generate_compatible_configs (model_name , target_os ))
299+ print (f"Discovered all supported configs for model '{ model_name } ': { configs } " )
300+ if target_configs is not None :
301+ for config in target_configs :
302+ if config not in configs :
303+ raise Exception (
304+ f"Unsupported config '{ config } ' for model '{ model_name } ' on '{ target_os } '. Skipped.\n "
305+ f"Supported configs are: { configs } "
180306 )
181- else :
182- # Non-LLaMA models
183- configs .append ("hf_xnnpack_fp32" )
184- elif model_name in MODEL_NAME_TO_MODEL :
185- # ExecuTorch in-tree non-GenAI models
186- configs .append ("xnnpack_q8" )
187- configs .extend (
188- [
189- config
190- for config in BENCHMARK_CONFIGS .get (target_os , [])
191- if not config .startswith ("llama" )
192- ]
193- )
194- else :
195- # Skip unknown models with a warning
196- logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
197- continue
307+ configs = target_configs
308+ print (f"Using provided configs { configs } for model '{ model_name } '" )
198309
199310 # Add configurations for each valid device
200311 for device in devices :
0 commit comments