Skip to content

Commit 64e2b17

Browse files
add support in cache for numpy objects in tunable parameters
1 parent e4bdb2e commit 64e2b17

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

kernel_tuner/util.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,17 @@ def compile_restrictions(restrictions: list, tune_params: dict):
708708
return func
709709

710710

711+
class NpEncoder(json.JSONEncoder):
712+
def default(self, obj):
713+
if isinstance(obj, np.integer):
714+
return int(obj)
715+
if isinstance(obj, np.floating):
716+
return float(obj)
717+
if isinstance(obj, np.ndarray):
718+
return obj.tolist()
719+
return super(NpEncoder, self).default(obj)
720+
721+
711722
def process_cache(cache, kernel_options, tuning_options, runner):
712723
"""cache file for storing tuned configurations
713724
@@ -747,7 +758,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
747758
c["tune_params"] = tuning_options.tune_params
748759
c["cache"] = {}
749760

750-
contents = json.dumps(c, indent="")[:-3] # except the last "}\n}"
761+
contents = json.dumps(c, cls=NpEncoder, indent="")[:-3] # except the last "}\n}"
751762

752763
# write the header to the cachefile
753764
with open(cache, "w") as cachefile:

0 commit comments

Comments
 (0)