99import logging
1010import os
1111import re
12- from typing import Any , Dict
12+ import sys
13+ from typing import Any , Dict , List , NamedTuple
1314
15+ sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
1416from examples .models import MODEL_NAME_TO_MODEL
1517
1618
4547}
4648
4749
class DisabledConfig(NamedTuple):
    """A benchmark config that is disabled for a model, with its tracking issue."""

    config_name: str
    github_issue: str  # Link to the GitHub issue
53+
54+
# Per-model benchmark configs that are currently disabled. Each entry carries
# the GitHub issue tracking why it is disabled, so exclusions stay auditable.
DISABLED_CONFIGS: Dict[str, List[DisabledConfig]] = {
    "resnet50": [
        DisabledConfig(
            config_name="qnn_q8",
            github_issue="https://github.com/pytorch/executorch/issues/7892",
        ),
    ],
    "w2l": [
        DisabledConfig(
            config_name="qnn_q8",
            github_issue="https://github.com/pytorch/executorch/issues/7634",
        ),
    ],
    "mobilebert": [
        DisabledConfig(
            config_name="mps",
            github_issue="https://github.com/pytorch/executorch/issues/7904",
        ),
        DisabledConfig(
            config_name="qnn_q8",
            github_issue="https://github.com/pytorch/executorch/issues/7946",
        ),
    ],
    "edsr": [
        DisabledConfig(
            config_name="mps",
            github_issue="https://github.com/pytorch/executorch/issues/7905",
        ),
    ],
    "llama": [
        DisabledConfig(
            config_name="mps",
            github_issue="https://github.com/pytorch/executorch/issues/7907",
        ),
    ],
}
92+
93+
def extract_all_configs(data, target_os=None):
    """
    Recursively flatten a nested config structure into a flat list of leaves.

    Dicts are walked per-key: when target_os is given, only the "xplat" branch
    and the target_os branch are followed (note this filter applies at every
    dict level of the recursion); otherwise all branches are followed. Lists
    are flattened element-wise, and any other value is treated as a leaf.
    """
    if isinstance(data, dict):
        if target_os:
            # Restrict to the shared branch plus the requested OS branch.
            allowed = {"xplat", target_os}
        else:
            allowed = data.keys()
        collected = []
        for branch, subtree in data.items():
            if branch in allowed:
                collected.extend(extract_all_configs(subtree, target_os))
        return collected
    if isinstance(data, list):
        flattened = []
        for entry in data:
            flattened.extend(extract_all_configs(entry, target_os))
        return flattened
    # Leaf value.
    return [data]
108+
109+
def generate_compatible_configs(model_name: str, target_os=None) -> List[str]:
    """
    Generate a list of compatible benchmark configurations for a given model name and target OS.

    Args:
        model_name (str): The name of the model to generate configurations for.
        target_os (Optional[str]): The target operating system (e.g., 'android', 'ios').

    Returns:
        List[str]: A list of compatible benchmark configurations.

    Raises:
        None

    Example:
        generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane']
    """
    configs = []
    if is_valid_huggingface_model_id(model_name):
        if model_name.startswith("meta-llama/"):
            # LLaMA models
            repo_name = model_name.split("meta-llama/")[1]
            if "qlora" in repo_name.lower():
                configs.append("llama3_qlora")
            elif "spinquant" in repo_name.lower():
                configs.append("llama3_spinquant")
            else:
                configs.append("llama3_fb16")
                configs.extend(
                    [
                        config
                        for config in BENCHMARK_CONFIGS.get(target_os, [])
                        if config.startswith("llama")
                    ]
                )
        else:
            # Non-LLaMA models
            configs.append("hf_xnnpack_fp32")
    elif model_name in MODEL_NAME_TO_MODEL:
        # ExecuTorch in-tree non-GenAI models
        configs.append("xnnpack_q8")
        if target_os != "xplat":
            # Add OS-specific configs
            configs.extend(
                [
                    config
                    for config in BENCHMARK_CONFIGS.get(target_os, [])
                    if not config.startswith("llama")
                ]
            )
    else:
        # Skip unknown models with a warning
        logging.warning(f"Unknown or invalid model name '{model_name}'. Skipping.")

    # Remove disabled configs for the given model. Fix: only announce an
    # exclusion when the disabled config was actually discovered above — the
    # previous version logged every disabled entry even when it was never a
    # candidate for this target_os, which made the output misleading.
    disabled_issues = {
        disabled.config_name: disabled.github_issue
        for disabled in DISABLED_CONFIGS.get(model_name, [])
    }
    kept = []
    for config in configs:
        if config in disabled_issues:
            print(
                f"Excluding disabled config: '{config}' for model '{model_name}' on '{target_os}'. Linked GitHub issue: {disabled_issues[config]}"
            )
        else:
            kept.append(config)
    return kept
173+
174+
48175def parse_args () -> Any :
49176 """
50177 Parse command-line arguments.
@@ -82,6 +209,11 @@ def comma_separated(value: str):
82209 type = comma_separated , # Use the custom parser for comma-separated values
83210 help = f"Comma-separated device names. Available devices: { list (DEVICE_POOLS .keys ())} " ,
84211 )
212+ parser .add_argument (
213+ "--configs" ,
214+ type = comma_separated , # Use the custom parser for comma-separated values
215+ help = f"Comma-separated benchmark configs. Available configs: { extract_all_configs (BENCHMARK_CONFIGS )} " ,
216+ )
85217
86218 return parser .parse_args ()
87219
@@ -98,11 +230,16 @@ def set_output(name: str, val: Any) -> None:
98230 set_output("benchmark_configs", {"include": [...]})
99231 """
100232
101- if os .getenv ("GITHUB_OUTPUT" ):
102- print (f"Setting { val } to GitHub output" )
103- with open (str (os .getenv ("GITHUB_OUTPUT" )), "a" ) as env :
104- print (f"{ name } ={ val } " , file = env )
105- else :
233+ github_output = os .getenv ("GITHUB_OUTPUT" )
234+ if not github_output :
235+ print (f"::set-output name={ name } ::{ val } " )
236+ return
237+
238+ try :
239+ with open (github_output , "a" ) as env :
240+ env .write (f"{ name } ={ val } \n " )
241+ except PermissionError :
242+ # Fall back to printing in case of permission error in unit tests
106243 print (f"::set-output name={ name } ::{ val } " )
107244
108245
@@ -123,7 +260,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123260 return bool (re .match (pattern , model_name ))
124261
125262
126- def get_benchmark_configs () -> Dict [str , Dict ]:
263+ def get_benchmark_configs () -> Dict [str , Dict ]: # noqa: C901
127264 """
128265 Gather benchmark configurations for a given set of models on the target operating system and devices.
129266
@@ -153,48 +290,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153290 }
154291 """
155292 args = parse_args ()
156- target_os = args .os
157293 devices = args .devices
158294 models = args .models
295+ target_os = args .os
296+ target_configs = args .configs
159297
160298 benchmark_configs = {"include" : []}
161299
162300 for model_name in models :
163301 configs = []
164- if is_valid_huggingface_model_id (model_name ):
165- if model_name .startswith ("meta-llama/" ):
166- # LLaMA models
167- repo_name = model_name .split ("meta-llama/" )[1 ]
168- if "qlora" in repo_name .lower ():
169- configs .append ("llama3_qlora" )
170- elif "spinquant" in repo_name .lower ():
171- configs .append ("llama3_spinquant" )
172- else :
173- configs .append ("llama3_fb16" )
174- configs .extend (
175- [
176- config
177- for config in BENCHMARK_CONFIGS .get (target_os , [])
178- if config .startswith ("llama" )
179- ]
302+ configs .extend (generate_compatible_configs (model_name , target_os ))
303+ print (f"Discovered all supported configs for model '{ model_name } ': { configs } " )
304+ if target_configs is not None :
305+ for config in target_configs :
306+ if config not in configs :
307+ raise Exception (
308+ f"Unsupported config '{ config } ' for model '{ model_name } ' on '{ target_os } '. Skipped.\n "
309+ f"Supported configs are: { configs } "
180310 )
181- else :
182- # Non-LLaMA models
183- configs .append ("hf_xnnpack_fp32" )
184- elif model_name in MODEL_NAME_TO_MODEL :
185- # ExecuTorch in-tree non-GenAI models
186- configs .append ("xnnpack_q8" )
187- configs .extend (
188- [
189- config
190- for config in BENCHMARK_CONFIGS .get (target_os , [])
191- if not config .startswith ("llama" )
192- ]
193- )
194- else :
195- # Skip unknown models with a warning
196- logging .warning (f"Unknown or invalid model name '{ model_name } '. Skipping." )
197- continue
311+ configs = target_configs
312+ print (f"Using provided configs { configs } for model '{ model_name } '" )
198313
199314 # Add configurations for each valid device
200315 for device in devices :
0 commit comments