99import  logging 
1010import  os 
1111import  re 
12- from  typing  import  Any , Dict 
12+ import  sys 
13+ from  typing  import  Any , Dict , List 
1314
15+ sys .path .append (os .path .abspath (os .path .join (os .path .dirname (__file__ ), "../.." )))
1416from  examples .models  import  MODEL_NAME_TO_MODEL 
1517
1618
4547}
4648
4749
50+ def  extract_all_configs (data , target_os = None ):
51+     if  isinstance (data , dict ):
52+         # If target_os is specified, include "xplat" and the specified branch 
53+         include_branches  =  {"xplat" , target_os } if  target_os  else  data .keys ()
54+         return  [
55+             v 
56+             for  key , value  in  data .items ()
57+             if  key  in  include_branches 
58+             for  v  in  extract_all_configs (value , target_os )
59+         ]
60+     elif  isinstance (data , list ):
61+         return  [v  for  item  in  data  for  v  in  extract_all_configs (item , target_os )]
62+     else :
63+         return  [data ]
64+ 
65+ 
66+ def  generate_compatible_configs (model_name : str , target_os = None ) ->  List [str ]:
67+     """ 
68+     Generate a list of compatible benchmark configurations for a given model name and target OS. 
69+ 
70+     Args: 
71+         model_name (str): The name of the model to generate configurations for. 
72+         target_os (Optional[str]): The target operating system (e.g., 'android', 'ios'). 
73+ 
74+     Returns: 
75+         List[str]: A list of compatible benchmark configurations. 
76+ 
77+     Raises: 
78+         None 
79+ 
80+     Example: 
81+         generate_compatible_configs('meta-llama/Llama-3.2-1B', 'ios') -> ['llama3_fb16', 'llama3_coreml_ane'] 
82+     """ 
83+     configs  =  []
84+     if  is_valid_huggingface_model_id (model_name ):
85+         if  model_name .startswith ("meta-llama/" ):
86+             # LLaMA models 
87+             repo_name  =  model_name .split ("meta-llama/" )[1 ]
88+             if  "qlora"  in  repo_name .lower ():
89+                 configs .append ("llama3_qlora" )
90+             elif  "spinquant"  in  repo_name .lower ():
91+                 configs .append ("llama3_spinquant" )
92+             else :
93+                 configs .append ("llama3_fb16" )
94+                 configs .extend (
95+                     [
96+                         config 
97+                         for  config  in  BENCHMARK_CONFIGS .get (target_os , [])
98+                         if  config .startswith ("llama" )
99+                     ]
100+                 )
101+         else :
102+             # Non-LLaMA models 
103+             configs .append ("hf_xnnpack_fp32" )
104+     elif  model_name  in  MODEL_NAME_TO_MODEL :
105+         # ExecuTorch in-tree non-GenAI models 
106+         configs .append ("xnnpack_q8" )
107+         if  target_os  !=  "xplat" :
108+             # Add OS-specific configs 
109+             configs .extend (
110+                 [
111+                     config 
112+                     for  config  in  BENCHMARK_CONFIGS .get (target_os , [])
113+                     if  not  config .startswith ("llama" )
114+                 ]
115+             )
116+     else :
117+         # Skip unknown models with a warning 
118+         logging .warning (f"Unknown or invalid model name '{ model_name }  )
119+ 
120+     return  configs 
121+ 
122+ 
48123def  parse_args () ->  Any :
49124    """ 
50125    Parse command-line arguments. 
@@ -82,6 +157,11 @@ def comma_separated(value: str):
82157        type = comma_separated ,  # Use the custom parser for comma-separated values 
83158        help = f"Comma-separated device names. Available devices: { list (DEVICE_POOLS .keys ())}  ,
84159    )
160+     parser .add_argument (
161+         "--configs" ,
162+         type = comma_separated ,  # Use the custom parser for comma-separated values 
163+         help = f"Comma-separated benchmark configs. Available configs: { extract_all_configs (BENCHMARK_CONFIGS )}  ,
164+     )
85165
86166    return  parser .parse_args ()
87167
@@ -98,11 +178,16 @@ def set_output(name: str, val: Any) -> None:
98178        set_output("benchmark_configs", {"include": [...]}) 
99179    """ 
100180
101-     if  os .getenv ("GITHUB_OUTPUT" ):
102-         print (f"Setting { val }  )
103-         with  open (str (os .getenv ("GITHUB_OUTPUT" )), "a" ) as  env :
104-             print (f"{ name } { val }  , file = env )
105-     else :
181+     github_output  =  os .getenv ("GITHUB_OUTPUT" )
182+     if  not  github_output :
183+         print (f"::set-output name={ name } { val }  )
184+         return 
185+ 
186+     try :
187+         with  open (github_output , "a" ) as  env :
188+             env .write (f"{ name } { val } \n " )
189+     except  PermissionError :
190+         # Fall back to printing in case of permission error in unit tests 
106191        print (f"::set-output name={ name } { val }  )
107192
108193
@@ -123,7 +208,7 @@ def is_valid_huggingface_model_id(model_name: str) -> bool:
123208    return  bool (re .match (pattern , model_name ))
124209
125210
126- def  get_benchmark_configs () ->  Dict [str , Dict ]:
211+ def  get_benchmark_configs () ->  Dict [str , Dict ]:   # noqa: C901 
127212    """ 
128213    Gather benchmark configurations for a given set of models on the target operating system and devices. 
129214
@@ -153,48 +238,26 @@ def get_benchmark_configs() -> Dict[str, Dict]:
153238        } 
154239    """ 
155240    args  =  parse_args ()
156-     target_os  =  args .os 
157241    devices  =  args .devices 
158242    models  =  args .models 
243+     target_os  =  args .os 
244+     target_configs  =  args .configs 
159245
160246    benchmark_configs  =  {"include" : []}
161247
162248    for  model_name  in  models :
163249        configs  =  []
164-         if  is_valid_huggingface_model_id (model_name ):
165-             if  model_name .startswith ("meta-llama/" ):
166-                 # LLaMA models 
167-                 repo_name  =  model_name .split ("meta-llama/" )[1 ]
168-                 if  "qlora"  in  repo_name .lower ():
169-                     configs .append ("llama3_qlora" )
170-                 elif  "spinquant"  in  repo_name .lower ():
171-                     configs .append ("llama3_spinquant" )
172-                 else :
173-                     configs .append ("llama3_fb16" )
174-                     configs .extend (
175-                         [
176-                             config 
177-                             for  config  in  BENCHMARK_CONFIGS .get (target_os , [])
178-                             if  config .startswith ("llama" )
179-                         ]
250+         configs .extend (generate_compatible_configs (model_name , target_os ))
251+         print (f"Discovered all supported configs for model '{ model_name } { configs }  )
252+         if  target_configs  is  not None :
253+             for  config  in  target_configs :
254+                 if  config  not  in configs :
255+                     raise  Exception (
256+                         f"Unsupported config '{ config } { model_name } { target_os } \n " 
257+                         f"Supported configs are: { configs }  
180258                    )
181-             else :
182-                 # Non-LLaMA models 
183-                 configs .append ("hf_xnnpack_fp32" )
184-         elif  model_name  in  MODEL_NAME_TO_MODEL :
185-             # ExecuTorch in-tree non-GenAI models 
186-             configs .append ("xnnpack_q8" )
187-             configs .extend (
188-                 [
189-                     config 
190-                     for  config  in  BENCHMARK_CONFIGS .get (target_os , [])
191-                     if  not  config .startswith ("llama" )
192-                 ]
193-             )
194-         else :
195-             # Skip unknown models with a warning 
196-             logging .warning (f"Unknown or invalid model name '{ model_name }  )
197-             continue 
259+             configs  =  target_configs 
260+             print (f"Using provided configs { configs } { model_name }  )
198261
199262        # Add configurations for each valid device 
200263        for  device  in  devices :
0 commit comments