Skip to content

Commit a5c3189

Browse files
committed
Test for store_output_file and store_metadata_file, fixed an issue where the filename would not be extended properly
1 parent 2e389de commit a5c3189

File tree

2 files changed

+76
-23
lines changed

2 files changed

+76
-23
lines changed

kernel_tuner/file_utils.py

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -10,25 +10,27 @@
1010

1111
from kernel_tuner import util
1212

13-
1413
schema_dir = os.path.dirname(os.path.realpath(__file__)) + "/schema"
1514

15+
1616
def output_file_schema(target):
1717
current_version = "1.0.0"
1818
file = schema_dir + f"/T4/{current_version}/{target}-schema.json"
1919
with open(file, 'r') as fh:
2020
json_string = json.load(fh)
2121
return current_version, json_string
2222

23-
def store_output_file(output_filename, results, tune_params, objective="time"):
2423

25-
if output_filename[-5] != ".json":
24+
def store_output_file(output_filename, results, tune_params, objective="time"):
25+
if output_filename[-5:] != ".json":
2626
output_filename += ".json"
2727

28-
print(f"Producing {output_filename}")
29-
30-
timing_keys = ["compile_time", "benchmark_time", "framework_time", "strategy_time", "verification_time"]
31-
not_measurement_keys = list(tune_params.keys()) + timing_keys + ["timestamp"]
28+
timing_keys = [
29+
"compile_time", "benchmark_time", "framework_time", "strategy_time",
30+
"verification_time"
31+
]
32+
not_measurement_keys = list(
33+
tune_params.keys()) + timing_keys + ["timestamp"]
3234

3335
output_data = []
3436

@@ -37,7 +39,10 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
3739
out = {}
3840

3941
out["timestamp"] = result["timestamp"]
40-
out["configuration"] = {k:v for k, v in result.items() if k in tune_params}
42+
out["configuration"] = {
43+
k: v
44+
for k, v in result.items() if k in tune_params
45+
}
4146

4247
# collect configuration specific timings
4348
timings = dict()
@@ -82,13 +87,10 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
8287
# append to output
8388
output_data.append(out)
8489

85-
# validate schema
86-
version, schema = output_file_schema("results")
87-
output_json = dict(results=output_data, schema_version=version)
88-
validate(output_json, schema=schema)
89-
9090
# write output_data to a JSON file
91-
with open(output_filename, 'w') as fh:
91+
version, _ = output_file_schema("results")
92+
output_json = dict(results=output_data, schema_version=version)
93+
with open(output_filename, 'w+') as fh:
9294
json.dump(output_json, fh)
9395

9496

@@ -108,32 +110,36 @@ def get_dependencies(package='kernel_tuner'):
108110

109111
def get_device_query(target):
110112
if target == "nvidia":
111-
nvidia_smi_out = subprocess.run(["nvidia-smi", "--query", "-x"], capture_output=True)
113+
nvidia_smi_out = subprocess.run(["nvidia-smi", "--query", "-x"],
114+
capture_output=True)
112115
nvidia_smi = xmltodict.parse(nvidia_smi_out.stdout)
113116
del nvidia_smi["nvidia_smi_log"]["gpu"]["processes"]
114117
return nvidia_smi
115118
elif target == "amd":
116-
rocm_smi_out = subprocess.run(["rocm-smi", "--showallinfo", "--json"], capture_output=True)
119+
rocm_smi_out = subprocess.run(["rocm-smi", "--showallinfo", "--json"],
120+
capture_output=True)
117121
return json.loads(rocm_smi_out.stdout)
118122
else:
119123
raise ValueError("get_device_query target not supported")
120124

121125

122126
def store_metadata_file(metadata_filename, target="nvidia"):
127+
if metadata_filename[-5:] != ".json":
128+
metadata_filename += ".json"
123129
metadata = {}
124130

131+
# lshw only works on Linux, this intentionally raises a FileNotFoundError when ran on systems that do not have it
125132
lshw_out = subprocess.run(["lshw", "-json"], capture_output=True)
126133
metadata["hardware"] = dict(lshw=json.loads(lshw_out.stdout))
127134

135+
# only works if nvidia-smi (for NVIDIA) or rocm-smi (for AMD) is present, raises FileNotFoundError when not present
128136
device_query = get_device_query(target)
129137

130-
metadata["environment"] = dict(device_query=device_query, requirements=get_dependencies())
131-
132-
# validate schema
133-
version, schema = output_file_schema("metadata")
134-
metadata_json = dict(metadata=metadata, schema_version=version)
135-
validate(metadata_json, schema=schema)
138+
metadata["environment"] = dict(device_query=device_query,
139+
requirements=get_dependencies())
136140

137141
# write metadata to JSON file
138-
with open(metadata_filename, 'w') as fh:
142+
version, _ = output_file_schema("metadata")
143+
metadata_json = dict(metadata=metadata, schema_version=version)
144+
with open(metadata_filename, 'w+') as fh:
139145
json.dump(metadata_json, fh, indent=" ")

test/test_file_utils.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
from kernel_tuner.file_utils import store_output_file, store_metadata_file, output_file_schema, validate
2+
from .test_integration import fake_results
3+
import pytest
4+
import json
5+
import os
6+
7+
8+
def test_store_output_file(fake_results):
9+
# setup variables
10+
filename = "test_output_file.json"
11+
_, _, tune_params, _, _, results, _ = fake_results
12+
13+
# run store_output_file
14+
store_output_file(filename, results, tune_params)
15+
16+
# retrieve output file
17+
_, schema = output_file_schema("results")
18+
with open(filename) as json_file:
19+
output_json = json.load(json_file)
20+
21+
# validate
22+
validate(output_json, schema=schema)
23+
24+
# clean up
25+
os.remove(filename)
26+
27+
28+
def test_store_metadata_file():
29+
# setup variables
30+
filename = "test_metadata_file.json"
31+
32+
# run store_metadata_file
33+
try:
34+
store_metadata_file(filename, target="nvidia")
35+
except FileNotFoundError:
36+
pytest.skip("'lshw' or 'nvidia-smi' not present on this system")
37+
38+
# retrieve metadata file
39+
_, schema = output_file_schema("metadata")
40+
with open(filename) as json_file:
41+
metadata_json = json.load(json_file)
42+
43+
# validate
44+
validate(metadata_json, schema=schema)
45+
46+
# clean up
47+
os.remove(filename)

0 commit comments

Comments
 (0)