@@ -528,7 +528,8 @@ def get_kernel_string(kernel_source, params=None):
528528 kernel_string = read_file (kernel_source )
529529 elif isinstance (kernel_source , str ):
530530 if looks_like_a_filename (kernel_source ):
531- kernel_string = read_file (kernel_source ) or kernel_source
531+ with open (kernel_source , "r" ) as f :
532+ kernel_string = f .read ()
532533 else :
533534 kernel_string = kernel_source
534535 else :
@@ -1123,6 +1124,10 @@ def compile_restrictions(
11231124 noncompiled_restrictions .append ((r , [], r ))
11241125 return noncompiled_restrictions + compiled_restrictions
11251126
1127+ def check_matching_problem_size (cached_problem_size , problem_size ):
1128+ """Check the if requested problem size matches the problem size in the cache."""
1129+ if not (np .array (cached_problem_size ) == np .array (problem_size )).all ():
1130+ raise ValueError (f"Cannot load cache which contains results for different problem_size, cache: { cached_problem_size } , requested: { problem_size } " )
11261131
11271132def process_cache (cache , kernel_options , tuning_options , runner ):
11281133 """Cache file for storing tuned configurations.
@@ -1193,18 +1198,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
11931198 f"Cannot load cache which contains results for different kernel (cache: { cached_data ['kernel_name' ]} , actual: { kernel_options .kernel_name } )"
11941199 )
11951200 if "problem_size" in cached_data and not callable (kernel_options .problem_size ):
1196- # if it's a single value, convert to an array
1197- if isinstance (cached_data ["problem_size" ], int ):
1198- cached_data ["problem_size" ] = [cached_data ["problem_size" ]]
1199- # if problem_size is not iterable, compare directly
1200- if not hasattr (kernel_options .problem_size , "__iter__" ):
1201- if cached_data ["problem_size" ] != kernel_options .problem_size :
1202- raise ValueError ("Cannot load cache which contains results for different problem_size" )
1203- # else (problem_size is iterable)
1204- # cache returns list, problem_size is likely a tuple. Therefore, the next check
1205- # checks the equality of all items in the list/tuples individually
1206- elif not all ([i == j for i , j in zip (cached_data ["problem_size" ], kernel_options .problem_size )]):
1207- raise ValueError ("Cannot load cache which contains results for different problem_size" )
1201+ check_matching_problem_size (cached_data ["problem_size" ], kernel_options .problem_size )
12081202 if cached_data ["tune_params_keys" ] != list (tuning_options .tune_params .keys ()):
12091203 if all (key in tuning_options .tune_params for key in cached_data ["tune_params_keys" ]):
12101204 raise ValueError (
0 commit comments