55# This source code is licensed under the BSD-style license found in the
66# LICENSE file in the root directory of this source tree.
77
8+ import glob
89import json
910import logging
1011import os
2223
2324BENCHMARK_RESULTS_FILENAME = "benchmark_results.json"
2425ARTIFACTS_FILENAME_REGEX = re .compile (r"(android|ios)-artifacts-(?P<job_id>\d+).json" )
26+ BENCHMARK_CONFIG_REGEX = re .compile (r"The benchmark config is (?P<benchmark_config>.+)" )
2527
2628# iOS-related regexes and variables
2729IOS_TEST_SPEC_REGEX = re .compile (
@@ -51,7 +53,7 @@ def __call__(
5153 parser .error (f"{ values } is not a valid JSON file (*.json)" )
5254
5355
54- class ValidateOutputDir (Action ):
56+ class ValidateDir (Action ):
5557 def __call__ (
5658 self ,
5759 parser : ArgumentParser ,
@@ -81,7 +83,7 @@ def parse_args() -> Any:
8183 "--output-dir" ,
8284 type = str ,
8385 required = True ,
84- action = ValidateOutputDir ,
86+ action = ValidateDir ,
8587 help = "the directory to keep the benchmark results" ,
8688 )
8789 parser .add_argument (
@@ -114,6 +116,13 @@ def parse_args() -> Any:
114116 required = True ,
115117 help = "which retry of the workflow this is" ,
116118 )
119+ parser .add_argument (
120+ "--benchmark-configs" ,
121+ type = str ,
122+ required = True ,
123+ action = ValidateDir ,
124+ help = "the directory to keep the benchmark configs" ,
125+ )
117126
118127 return parser .parse_args ()
119128
@@ -300,9 +309,60 @@ def extract_job_id(artifacts_filename: str) -> int:
300309 return int (m .group ("job_id" ))
301310
302311
312+ def read_all_benchmark_configs () -> Dict [str , Dict [str , str ]]:
313+ """
314+ Read all the benchmark configs that we can find
315+ """
316+ benchmark_configs = {}
317+
318+ for file in glob .glob (f"{ benchmark_configs } /*.json" ):
319+ filename = os .path .basename (file )
320+ with open (file ) as f :
321+ try :
322+ benchmark_configs [filename ] = json .load (f )
323+ except json .JSONDecodeError as e :
324+ warning (f"Fail to load benchmark config { file } : { e } " )
325+
326+ return benchmark_configs
327+
328+
329+ def read_benchmark_config (
330+ artifact_s3_url : str , benchmark_configs_dir : str
331+ ) -> Dict [str , str ]:
332+ """
333+ Get the correct benchmark config for this benchmark run
334+ """
335+ try :
336+ with request .urlopen (artifact_s3_url ) as data :
337+ for line in data .read ().decode ("utf8" ).splitlines ():
338+ m = BENCHMARK_CONFIG_REGEX .match (line )
339+ if not m :
340+ continue
341+
342+ benchmark_config = m .group ("benchmark_config" )
343+ filename = os .path .join (
344+ benchmark_configs_dir , f"{ benchmark_config } .json"
345+ )
346+
347+ if not os .path .exists (filename ):
348+ warning (f"There is no benchmark config { filename } " )
349+ continue
350+
351+ with open (filename ) as f :
352+ try :
353+ return json .load (f )
354+ except json .JSONDecodeError as e :
355+ warning (f"Fail to load benchmark config { filename } : { e } " )
356+ except error .HTTPError :
357+ warning (f"Fail to read the test spec output at { artifact_s3_url } " )
358+
359+ return {}
360+
361+
303362def transform (
304363 app_type : str ,
305364 benchmark_results : List ,
365+ benchmark_config : Dict [str , str ],
306366 repo : str ,
307367 head_branch : str ,
308368 workflow_name : str ,
@@ -352,29 +412,25 @@ def transform(
352412 for r in benchmark_results
353413 ]
354414 elif schema_version == "v3" :
355- quantization = (
356- r ["benchmarkModel" ]["quantization" ]
357- if r ["benchmarkModel" ]["quantization" ]
358- else "unknown"
359- )
415+ v3_benchmark_results = []
360416 # From https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
361417 return [
362418 {
363419 "benchmark" : {
364420 "name" : "ExecuTorch" ,
365421 "mode" : "inference" ,
366- "dtype" : quantization ,
367422 "extra_info" : {
368423 "app_type" : app_type ,
424+ # Just keep a copy of the benchmark config here
425+ "benchmark_config" : json .dumps (benchmark_config ),
369426 },
370427 },
371428 "model" : {
372- "name" : r ["benchmarkModel" ]["name" ],
429+ "name" : benchmark_config . get ( "model" , r ["benchmarkModel" ]["name" ]) ,
373430 "type" : "OSS model" ,
374- "backend" : r ["benchmarkModel" ].get ("backend" , "" ),
375- "extra_info" : {
376- "quantization" : quantization ,
377- },
431+ "backend" : benchmark_config .get (
432+ "config" , r ["benchmarkModel" ].get ("backend" , "" )
433+ ),
378434 },
379435 "metric" : {
380436 "name" : r ["metric" ],
@@ -405,6 +461,7 @@ def main() -> None:
405461 "v2" : [],
406462 "v3" : [],
407463 }
464+ benchmark_config = {}
408465
409466 with open (args .artifacts ) as f :
410467 for artifact in json .load (f ):
@@ -420,6 +477,11 @@ def main() -> None:
420477 artifact_type = artifact ["type" ]
421478 artifact_s3_url = artifact ["s3_url" ]
422479
480+ if artifact_type == "TESTSPEC_OUTPUT" :
481+ benchmark_config = read_benchmark_config (
482+ artifact_s3_url , args .benchmark_configs
483+ )
484+
423485 if app_type == "ANDROID_APP" :
424486 benchmark_results = extract_android_benchmark_results (
425487 job_name , artifact_type , artifact_s3_url
@@ -435,6 +497,7 @@ def main() -> None:
435497 results = transform (
436498 app_type ,
437499 benchmark_results ,
500+ benchmark_config ,
438501 args .repo ,
439502 args .head_branch ,
440503 args .workflow_name ,
0 commit comments