55import re
66import os
77
8+
89class AccuracyCheck (BaseCheck ):
9- def __init__ (self , log , path , config : Config , submission_logs : SubmissionLogs ):
10+ def __init__ (
11+ self , log , path , config : Config , submission_logs : SubmissionLogs
12+ ):
1013 super ().__init__ (log , path )
1114 self .name = "accuracy checks"
1215 self .submission_logs = submission_logs
@@ -15,9 +18,12 @@ def __init__(self, log, path, config: Config, submission_logs: SubmissionLogs):
1518 self .accuracy_json = self . submission_logs .accuracy_json
1619 self .config = config
1720 self .model = self .submission_logs .loader_data .get ("benchmark" , "" )
18- self .model_mapping = self .submission_logs .loader_data .get ("model_mapping" , {})
19- self .model = self .config .get_mlperf_model (self .model , self .model_mapping )
20- self .scenario_fixed = self .submission_logs .loader_data .get ("scenario" , "" )
21+ self .model_mapping = self .submission_logs .loader_data .get (
22+ "model_mapping" , {})
23+ self .model = self .config .get_mlperf_model (
24+ self .model , self .model_mapping )
25+ self .scenario_fixed = self .submission_logs .loader_data .get (
26+ "scenario" , "" )
2127 self .scenario = self .mlperf_log ["effective_scenario" ]
2228 self .division = self .submission_logs .loader_data .get ("division" , "" )
2329 self .setup_checks ()
@@ -97,7 +103,7 @@ def accuracy_result_check(self):
97103 if self .division .lower () == "open" :
98104 return True
99105 return is_valid
100-
106+
101107 def accuracy_json_check (self ):
102108 if not os .path .exists (self .accuracy_json ):
103109 self .log .error ("%s is missing" , self .accuracy_json )
@@ -107,7 +113,7 @@ def accuracy_json_check(self):
107113 self .log .error ("%s is not truncated" , self .accuracy_json )
108114 return False
109115 return True
110-
116+
111117 def loadgen_errors_check (self ):
112118 if self .mlperf_log .has_error ():
113119 if self .config .ignore_uncommited :
@@ -125,7 +131,7 @@ def loadgen_errors_check(self):
125131 )
126132 return False
127133 return True
128-
134+
129135 def dataset_check (self ):
130136 if self .config .skip_dataset_size_check :
131137 self .log .info (
@@ -139,4 +145,4 @@ def dataset_check(self):
139145 "%s accurcy run does not cover all dataset, accuracy samples: %s, dataset size: %s" , self .path , qsl_total_count , expected_qsl_total_count
140146 )
141147 return False
142- return True
148+ return True
0 commit comments