55# This source code is licensed under the BSD-style license found in the
66# LICENSE file in the root directory of this source tree.
77
8+ import glob
89import json
910import logging
1011import os
@@ -307,9 +308,33 @@ def extract_job_id(artifacts_filename: str) -> int:
307308 return int (m .group ("job_id" ))
308309
309310
311+ def read_benchmark_configs (benchmark_configs : str ) -> Dict [str , Dict [str , str ]]:
312+ """
313+ Read all the benchmark configs that we can find
314+ """
315+ benchmark_configs = {}
316+
317+ for file in glob .glob (f"{ benchmark_configs } /*.json" ):
318+ filename = os .path .basename (file )
319+ with open (file ) as f :
320+ try :
321+ benchmark_configs [filename ] = json .load (f )
322+ except json .JSONDecodeError as e :
323+ warning (f"Fail to load benchmark config { file } : { e } " )
324+
325+ return benchmark_configs
326+
327+
328+ def get_benchmark_configs (benchmark_configs : Dict [str , Dict [str , str ]]) -> str :
329+ """
330+ Get the correct benchmark config for this benchmark run
331+ """
332+
333+
310334def transform (
311335 app_type : str ,
312336 benchmark_results : List ,
337+ benchmark_configs : Dict [str , Dict [str , str ]],
313338 repo : str ,
314339 head_branch : str ,
315340 workflow_name : str ,
@@ -359,11 +384,6 @@ def transform(
359384 for r in benchmark_results
360385 ]
361386 elif schema_version == "v3" :
362- quantization = (
363- r ["benchmarkModel" ]["quantization" ]
364- if r ["benchmarkModel" ]["quantization" ]
365- else "unknown"
366- )
367387 # From https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database
368388 return [
369389 {
@@ -373,6 +393,7 @@ def transform(
373393 "dtype" : quantization ,
374394 "extra_info" : {
375395 "app_type" : app_type ,
396+ "benchmark_configs" :
376397 },
377398 },
378399 "model" : {
@@ -412,6 +433,7 @@ def main() -> None:
412433 "v2" : [],
413434 "v3" : [],
414435 }
436+ benchmark_configs = read_benchmark_configs (args .benchmark_configs )
415437
416438 with open (args .artifacts ) as f :
417439 for artifact in json .load (f ):
@@ -442,6 +464,7 @@ def main() -> None:
442464 results = transform (
443465 app_type ,
444466 benchmark_results ,
467+ benchmark_configs ,
445468 args .repo ,
446469 args .head_branch ,
447470 args .workflow_name ,
0 commit comments