Skip to content

Commit 02544eb

Browse files
remove redundant JSON encoder
1 parent 7c6f709 commit 02544eb

File tree

2 files changed

+15
-24
lines changed

2 files changed

+15
-24
lines changed

kernel_tuner/file_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def store_output_file(output_filename, results, tune_params, objective="time"):
131131
version, _ = output_file_schema("results")
132132
output_json = dict(results=output_data, schema_version=version)
133133
with open(output_filename, 'w+') as fh:
134-
json.dump(output_json, fh)
134+
json.dump(output_json, fh, cls=util.NpEncoder)
135135

136136

137137
def get_dependencies(package='kernel_tuner'):

kernel_tuner/util.py

Lines changed: 14 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,19 @@ class RuntimeFailedConfig(ErrorConfig):
4949
pass
5050

5151

52+
class NpEncoder(json.JSONEncoder):
53+
""" Class we use for dumping Numpy objects to JSON """
54+
55+
def default(self, obj):
56+
if isinstance(obj, np.integer):
57+
return int(obj)
58+
if isinstance(obj, np.floating):
59+
return float(obj)
60+
if isinstance(obj, np.ndarray):
61+
return obj.tolist()
62+
return super(NpEncoder, self).default(obj)
63+
64+
5265
class TorchPlaceHolder():
5366

5467
def __init__(self):
@@ -725,18 +738,6 @@ def compile_restrictions(restrictions: list, tune_params: dict):
725738
return func
726739

727740

728-
class NpEncoder(json.JSONEncoder):
729-
730-
def default(self, obj):
731-
if isinstance(obj, np.integer):
732-
return int(obj)
733-
if isinstance(obj, np.floating):
734-
return float(obj)
735-
if isinstance(obj, np.ndarray):
736-
return obj.tolist()
737-
return super(NpEncoder, self).default(obj)
738-
739-
740741
def process_cache(cache, kernel_options, tuning_options, runner):
741742
"""cache file for storing tuned configurations
742743
@@ -871,16 +872,6 @@ def close_cache(cache):
871872
def store_cache(key, params, tuning_options):
872873
""" stores a new entry (key, params) to the cachefile """
873874

874-
# create converter for dumping numpy objects to JSON
875-
def JSONconverter(obj):
876-
if isinstance(obj, np.integer):
877-
return int(obj)
878-
if isinstance(obj, np.floating):
879-
return float(obj)
880-
if isinstance(obj, np.ndarray):
881-
return obj.tolist()
882-
return obj.__str__()
883-
884875
#logging.debug('store_cache called, cache=%s, cachefile=%s' % (tuning_options.cache, tuning_options.cachefile))
885876
if isinstance(tuning_options.cache, dict):
886877
if not key in tuning_options.cache:
@@ -894,7 +885,7 @@ def JSONconverter(obj):
894885

895886
if tuning_options.cachefile:
896887
with open(tuning_options.cachefile, "a") as cachefile:
897-
cachefile.write("\n" + json.dumps({ key: output_params }, default=JSONconverter)[1:-1] + ",")
888+
cachefile.write("\n" + json.dumps({ key: output_params }, cls=NpEncoder)[1:-1] + ",")
898889

899890

900891
def dump_cache(obj: str, tuning_options):

0 commit comments

Comments
 (0)