1010
1111from kernel_tuner import util
1212
13-
1413schema_dir = os .path .dirname (os .path .realpath (__file__ )) + "/schema"
1514
15+
1616def output_file_schema (target ):
1717 current_version = "1.0.0"
1818 file = schema_dir + f"/T4/{ current_version } /{ target } -schema.json"
1919 with open (file , 'r' ) as fh :
2020 json_string = json .load (fh )
2121 return current_version , json_string
2222
23- def store_output_file (output_filename , results , tune_params , objective = "time" ):
2423
25- if output_filename [- 5 ] != ".json" :
24+ def store_output_file (output_filename , results , tune_params , objective = "time" ):
25+ if output_filename [- 5 :] != ".json" :
2626 output_filename += ".json"
2727
28- print (f"Producing { output_filename } " )
29-
30- timing_keys = ["compile_time" , "benchmark_time" , "framework_time" , "strategy_time" , "verification_time" ]
31- not_measurement_keys = list (tune_params .keys ()) + timing_keys + ["timestamp" ]
28+ timing_keys = [
29+ "compile_time" , "benchmark_time" , "framework_time" , "strategy_time" ,
30+ "verification_time"
31+ ]
32+ not_measurement_keys = list (
33+ tune_params .keys ()) + timing_keys + ["timestamp" ]
3234
3335 output_data = []
3436
@@ -37,7 +39,10 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
3739 out = {}
3840
3941 out ["timestamp" ] = result ["timestamp" ]
40- out ["configuration" ] = {k :v for k , v in result .items () if k in tune_params }
42+ out ["configuration" ] = {
43+ k : v
44+ for k , v in result .items () if k in tune_params
45+ }
4146
4247 # collect configuration specific timings
4348 timings = dict ()
@@ -82,13 +87,10 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
8287 # append to output
8388 output_data .append (out )
8489
85- # validate schema
86- version , schema = output_file_schema ("results" )
87- output_json = dict (results = output_data , schema_version = version )
88- validate (output_json , schema = schema )
89-
9090 # write output_data to a JSON file
91- with open (output_filename , 'w' ) as fh :
91+ version , _ = output_file_schema ("results" )
92+ output_json = dict (results = output_data , schema_version = version )
93+ with open (output_filename , 'w+' ) as fh :
9294 json .dump (output_json , fh )
9395
9496
@@ -108,32 +110,36 @@ def get_dependencies(package='kernel_tuner'):
108110
109111def get_device_query (target ):
110112 if target == "nvidia" :
111- nvidia_smi_out = subprocess .run (["nvidia-smi" , "--query" , "-x" ], capture_output = True )
113+ nvidia_smi_out = subprocess .run (["nvidia-smi" , "--query" , "-x" ],
114+ capture_output = True )
112115 nvidia_smi = xmltodict .parse (nvidia_smi_out .stdout )
113116 del nvidia_smi ["nvidia_smi_log" ]["gpu" ]["processes" ]
114117 return nvidia_smi
115118 elif target == "amd" :
116- rocm_smi_out = subprocess .run (["rocm-smi" , "--showallinfo" , "--json" ], capture_output = True )
119+ rocm_smi_out = subprocess .run (["rocm-smi" , "--showallinfo" , "--json" ],
120+ capture_output = True )
117121 return json .loads (rocm_smi_out .stdout )
118122 else :
119123 raise ValueError ("get_device_query target not supported" )
120124
121125
122126def store_metadata_file (metadata_filename , target = "nvidia" ):
127+ if metadata_filename [- 5 :] != ".json" :
128+ metadata_filename += ".json"
123129 metadata = {}
124130
131+ # lshw only works on Linux, this intentionally raises a FileNotFoundError when ran on systems that do not have it
125132 lshw_out = subprocess .run (["lshw" , "-json" ], capture_output = True )
126133 metadata ["hardware" ] = dict (lshw = json .loads (lshw_out .stdout ))
127134
135+ # only works if nvidia-smi (for NVIDIA) or rocm-smi (for AMD) is present, raises FileNotFoundError when not present
128136 device_query = get_device_query (target )
129137
130- metadata ["environment" ] = dict (device_query = device_query , requirements = get_dependencies ())
131-
132- # validate schema
133- version , schema = output_file_schema ("metadata" )
134- metadata_json = dict (metadata = metadata , schema_version = version )
135- validate (metadata_json , schema = schema )
138+ metadata ["environment" ] = dict (device_query = device_query ,
139+ requirements = get_dependencies ())
136140
137141 # write metadata to JSON file
138- with open (metadata_filename , 'w' ) as fh :
142+ version , _ = output_file_schema ("metadata" )
143+ metadata_json = dict (metadata = metadata , schema_version = version )
144+ with open (metadata_filename , 'w+' ) as fh :
139145 json .dump (metadata_json , fh , indent = " " )
0 commit comments