1515
1616def output_file_schema (target ):
1717 current_version = "1.0.0"
18- file = schema_dir + f"/T4/{ current_version } /{ target } -schema.json"
19- with open (file , 'r' ) as fh :
18+ output_file = schema_dir + f"/T4/{ current_version } /{ target } -schema.json"
19+ with open (output_file , 'r' ) as fh :
2020 json_string = json .load (fh )
2121 return current_version , json_string
2222
2323
24+ def get_configuration_validity (objective ) -> str :
25+ """ Convert internal Kernel Tuner error to string """
26+ if not isinstance (objective , util .ErrorConfig ):
27+ return "correct"
28+ else :
29+ if isinstance (objective , util .CompilationFailedConfig ):
30+ return "compile"
31+ elif isinstance (objective , util .RuntimeFailedConfig ):
32+ return "runtime"
33+ else :
34+ return "constraints"
35+
36+
2437def store_output_file (output_filename , results , tune_params , objective = "time" ):
2538 if output_filename [- 5 :] != ".json" :
2639 output_filename += ".json"
2740
28- timing_keys = ["compile_time" , "benchmark_time" , "framework_time" , "strategy_time" , "verification_time" ]
29- not_measurement_keys = list (tune_params .keys ()) + timing_keys + ["timestamp" ] + ["times" ]
41+ timing_keys = [
42+ "compile_time" , "benchmark_time" , "framework_time" , "strategy_time" ,
43+ "verification_time"
44+ ]
45+ not_measurement_keys = list (
46+ tune_params .keys ()) + timing_keys + ["timestamp" ] + ["times" ]
3047
3148 output_data = []
3249
@@ -35,8 +52,10 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
3552 out = {}
3653
3754 out ["timestamp" ] = result ["timestamp" ]
38- out ["configuration" ] = { k : v
39- for k , v in result .items () if k in tune_params }
55+ out ["configuration" ] = {
56+ k : v
57+ for k , v in result .items () if k in tune_params
58+ }
4059
4160 # collect configuration specific timings
4261 timings = dict ()
@@ -49,15 +68,7 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
4968 out ["times" ] = timings
5069
5170 # encode the validity of the configuration
52- if not isinstance (result [objective ], util .ErrorConfig ):
53- out ["invalidity" ] = "correct"
54- else :
55- if isinstance (result [objective ], util .CompilationFailedConfig ):
56- out ["invalidity" ] = "compile"
57- elif isinstance (result [objective ], util .RuntimeFailedConfig ):
58- out ["invalidity" ] = "runtime"
59- else :
60- out ["invalidity" ] = "constraints"
71+ out ["invalidity" ] = get_configuration_validity (result [objective ])
6172
6273 # Kernel Tuner does not support producing results of configs that fail the correctness check
6374 # therefore correctness is always 1
@@ -66,11 +77,11 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
6677 # measurements gathers everything that was measured
6778 measurements = []
6879 for key , value in result .items ():
69- if not key in not_measurement_keys :
70- if key . startswith ( "time" ):
71- measurements . append ( dict (name = key , value = value , unit = "ms" ))
72- else :
73- measurements . append ( dict ( name = key , value = value , unit = "" ))
80+ if key not in not_measurement_keys :
81+ measurements . append (
82+ dict (name = key ,
83+ value = value ,
84+ unit = "ms" if key . startswith ( "time" ) else "" ))
7485 out ["measurements" ] = measurements
7586
7687 # objectives
@@ -105,12 +116,14 @@ def get_dependencies(package='kernel_tuner'):
105116
106117def get_device_query (target ):
107118 if target == "nvidia" :
108- nvidia_smi_out = subprocess .run (["nvidia-smi" , "--query" , "-x" ], capture_output = True )
119+ nvidia_smi_out = subprocess .run (["nvidia-smi" , "--query" , "-x" ],
120+ capture_output = True )
109121 nvidia_smi = xmltodict .parse (nvidia_smi_out .stdout )
110122 del nvidia_smi ["nvidia_smi_log" ]["gpu" ]["processes" ]
111123 return nvidia_smi
112124 elif target == "amd" :
113- rocm_smi_out = subprocess .run (["rocm-smi" , "--showallinfo" , "--json" ], capture_output = True )
125+ rocm_smi_out = subprocess .run (["rocm-smi" , "--showallinfo" , "--json" ],
126+ capture_output = True )
114127 return json .loads (rocm_smi_out .stdout )
115128 else :
116129 raise ValueError ("get_device_query target not supported" )
@@ -128,7 +141,8 @@ def store_metadata_file(metadata_filename, target="nvidia"):
128141 # only works if nvidia-smi (for NVIDIA) or rocm-smi (for AMD) is present, raises FileNotFoundError when not present
129142 device_query = get_device_query (target )
130143
131- metadata ["environment" ] = dict (device_query = device_query , requirements = get_dependencies ())
144+ metadata ["environment" ] = dict (device_query = device_query ,
145+ requirements = get_dependencies ())
132146
133147 # write metadata to JSON file
134148 version , _ = output_file_schema ("metadata" )
0 commit comments