Skip to content

Commit 20dcff0

Browse files
committed
Improved code complexity and added docstring
1 parent 3e53074 commit 20dcff0

File tree

1 file changed

+37
-23
lines changed

1 file changed

+37
-23
lines changed

kernel_tuner/file_utils.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,35 @@
1515

1616
def output_file_schema(target):
1717
current_version = "1.0.0"
18-
file = schema_dir + f"/T4/{current_version}/{target}-schema.json"
19-
with open(file, 'r') as fh:
18+
output_file = schema_dir + f"/T4/{current_version}/{target}-schema.json"
19+
with open(output_file, 'r') as fh:
2020
json_string = json.load(fh)
2121
return current_version, json_string
2222

2323

24+
def get_configuration_validity(objective) -> str:
25+
""" Convert internal Kernel Tuner error to string """
26+
if not isinstance(objective, util.ErrorConfig):
27+
return "correct"
28+
else:
29+
if isinstance(objective, util.CompilationFailedConfig):
30+
return "compile"
31+
elif isinstance(objective, util.RuntimeFailedConfig):
32+
return "runtime"
33+
else:
34+
return "constraints"
35+
36+
2437
def store_output_file(output_filename, results, tune_params, objective="time"):
2538
if output_filename[-5:] != ".json":
2639
output_filename += ".json"
2740

28-
timing_keys = ["compile_time", "benchmark_time", "framework_time", "strategy_time", "verification_time"]
29-
not_measurement_keys = list(tune_params.keys()) + timing_keys + ["timestamp"] + ["times"]
41+
timing_keys = [
42+
"compile_time", "benchmark_time", "framework_time", "strategy_time",
43+
"verification_time"
44+
]
45+
not_measurement_keys = list(
46+
tune_params.keys()) + timing_keys + ["timestamp"] + ["times"]
3047

3148
output_data = []
3249

@@ -35,8 +52,10 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
3552
out = {}
3653

3754
out["timestamp"] = result["timestamp"]
38-
out["configuration"] = { k: v
39-
for k, v in result.items() if k in tune_params }
55+
out["configuration"] = {
56+
k: v
57+
for k, v in result.items() if k in tune_params
58+
}
4059

4160
# collect configuration specific timings
4261
timings = dict()
@@ -49,15 +68,7 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
4968
out["times"] = timings
5069

5170
# encode the validity of the configuration
52-
if not isinstance(result[objective], util.ErrorConfig):
53-
out["invalidity"] = "correct"
54-
else:
55-
if isinstance(result[objective], util.CompilationFailedConfig):
56-
out["invalidity"] = "compile"
57-
elif isinstance(result[objective], util.RuntimeFailedConfig):
58-
out["invalidity"] = "runtime"
59-
else:
60-
out["invalidity"] = "constraints"
71+
out["invalidity"] = get_configuration_validity(result[objective])
6172

6273
# Kernel Tuner does not support producing results of configs that fail the correctness check
6374
# therefore correctness is always 1
@@ -66,11 +77,11 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
6677
# measurements gathers everything that was measured
6778
measurements = []
6879
for key, value in result.items():
69-
if not key in not_measurement_keys:
70-
if key.startswith("time"):
71-
measurements.append(dict(name=key, value=value, unit="ms"))
72-
else:
73-
measurements.append(dict(name=key, value=value, unit=""))
80+
if key not in not_measurement_keys:
81+
measurements.append(
82+
dict(name=key,
83+
value=value,
84+
unit="ms" if key.startswith("time") else ""))
7485
out["measurements"] = measurements
7586

7687
# objectives
@@ -105,12 +116,14 @@ def get_dependencies(package='kernel_tuner'):
105116

106117
def get_device_query(target):
107118
if target == "nvidia":
108-
nvidia_smi_out = subprocess.run(["nvidia-smi", "--query", "-x"], capture_output=True)
119+
nvidia_smi_out = subprocess.run(["nvidia-smi", "--query", "-x"],
120+
capture_output=True)
109121
nvidia_smi = xmltodict.parse(nvidia_smi_out.stdout)
110122
del nvidia_smi["nvidia_smi_log"]["gpu"]["processes"]
111123
return nvidia_smi
112124
elif target == "amd":
113-
rocm_smi_out = subprocess.run(["rocm-smi", "--showallinfo", "--json"], capture_output=True)
125+
rocm_smi_out = subprocess.run(["rocm-smi", "--showallinfo", "--json"],
126+
capture_output=True)
114127
return json.loads(rocm_smi_out.stdout)
115128
else:
116129
raise ValueError("get_device_query target not supported")
@@ -128,7 +141,8 @@ def store_metadata_file(metadata_filename, target="nvidia"):
128141
# only works if nvidia-smi (for NVIDIA) or rocm-smi (for AMD) is present, raises FileNotFoundError when not present
129142
device_query = get_device_query(target)
130143

131-
metadata["environment"] = dict(device_query=device_query, requirements=get_dependencies())
144+
metadata["environment"] = dict(device_query=device_query,
145+
requirements=get_dependencies())
132146

133147
# write metadata to JSON file
134148
version, _ = output_file_schema("metadata")

0 commit comments

Comments
 (0)