@@ -728,6 +728,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
728728
729729 { device_name: "name of device"
730730 kernel_name: "name of kernel"
731+ problem_size: (int, int, int)
731732 tune_params_keys: list
732733 tune_params:
733734 cache: {
@@ -754,6 +755,7 @@ def process_cache(cache, kernel_options, tuning_options, runner):
754755 c = OrderedDict ()
755756 c ["device_name" ] = runner .dev .name
756757 c ["kernel_name" ] = kernel_options .kernel_name
758+ c ["problem_size" ] = kernel_options .problem_size if not callable (kernel_options .problem_size ) else "callable"
757759 c ["tune_params_keys" ] = list (tuning_options .tune_params .keys ())
758760 c ["tune_params" ] = tuning_options .tune_params
759761 c ["cache" ] = {}
@@ -780,6 +782,11 @@ def process_cache(cache, kernel_options, tuning_options, runner):
780782 raise ValueError ("Cannot load cache which contains results for different device" )
781783 if cached_data ["kernel_name" ] != kernel_options .kernel_name :
782784 raise ValueError ("Cannot load cache which contains results for different kernel" )
785+ if "problem_size" in cached_data and not callable (kernel_options .problem_size ):
786+ # cache returns list, problem_size is likely a tuple. Therefore, the next check
787+ # checks the equality of all items in the list/tuples individually
788+ if not all ([i == j for i , j in zip (cached_data ["problem_size" ], kernel_options .problem_size )]):
789+ raise ValueError ("Cannot load cache which contains results for different problem_size" )
783790 if cached_data ["tune_params_keys" ] != list (tuning_options .tune_params .keys ()):
784791 raise ValueError ("Cannot load cache which contains results obtained with different tunable parameters" )
785792
0 commit comments