41
41
###############################################################################
42
42
# Some helper functions
43
43
###############################################################################
44
+ def check_cache_file_exists (cache_file ):
45
+ # for autotuning, we save two files: .cuda and .options, we will check that
46
+ # these two files exists for the validity of cache
47
+ if os .path .exists (cache_file + ".options" ) and os .path .exists (cache_file + ".cuda" ):
48
+ return True
49
+ return False
50
+
51
+
44
52
def get_options_from_cache_file (name , * inputs , ** kwargs ):
45
53
options = None
46
54
if "cache" in kwargs and kwargs ["cache" ] and isinstance (kwargs ["cache" ], str ):
@@ -49,12 +57,12 @@ def get_options_from_cache_file(name, *inputs, **kwargs):
49
57
if "training" in kwargs and kwargs ["training" ]:
50
58
if (kwargs ["type" ] == "backward" ):
51
59
cache_file = cache_file + "_backward"
52
- if "tuner" in kwargs and os . path . exists (cache_file ):
60
+ if "tuner" in kwargs and check_cache_file_exists (cache_file ):
53
61
tuner = kwargs ["tuner" ]
54
62
loaded_options = tuner .load (cache_file , name , list (inputs ), 1 )
55
63
if len (loaded_options ) > 0 :
56
64
options = loaded_options [0 ]
57
- elif os . path . exists (cache_file ):
65
+ elif check_cache_file_exists (cache_file ):
58
66
tuner = TcAutotuner (kwargs ["tc_lang" ])
59
67
options = tuner .load (cache_file , name , list (inputs ))
60
68
return options
@@ -225,7 +233,6 @@ def autotune(self, *inputs, **kwargs):
225
233
kwargs .pop ("name" , None )
226
234
backward = True if backward_name is not None else False
227
235
hash_key = get_tc_hash_key (name , * input_tensors )
228
-
229
236
# lookup for the options in the cache. Whenever we make the call to
230
237
# autotune, tuning must happen. But if the kernel has been tuned earlier
231
238
# then we can use previous options to seed the tuning.
@@ -592,6 +599,10 @@ def autotune(self, *inputs, **kwargs):
592
599
kwargs .update (self .kwargs_define )
593
600
if self .tuner is None :
594
601
self .tuner = TcAutotuner (self .lang , ** kwargs )
602
+ else :
603
+ # we do the init again so that the autotuner parameters are updated
604
+ # properly if users change them
605
+ self .tuner .__init__ (self .lang , ** kwargs )
595
606
return self .tuner .autotune (* inputs , ** kwargs )
596
607
597
608
###############################################################################
0 commit comments