Skip to content
This repository was archived by the owner on Apr 28, 2023. It is now read-only.

Commit a5d9493

Browse files
authored
Merge pull request #81 from facebookresearch/proper-check-file
Correctly check the autotuner cache file exists
2 parents bfa1a61 + 5e0972d commit a5d9493

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

tensor_comprehensions/tc_unit.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,14 @@
4141
###############################################################################
4242
# Some helper functions
4343
###############################################################################
44+
def check_cache_file_exists(cache_file):
45+
# for autotuning, we save two files: .cuda and .options, we will check that
46+
# these two files exists for the validity of cache
47+
if os.path.exists(cache_file + ".options") and os.path.exists(cache_file + ".cuda"):
48+
return True
49+
return False
50+
51+
4452
def get_options_from_cache_file(name, *inputs, **kwargs):
4553
options = None
4654
if "cache" in kwargs and kwargs["cache"] and isinstance(kwargs["cache"], str):
@@ -49,12 +57,12 @@ def get_options_from_cache_file(name, *inputs, **kwargs):
4957
if "training" in kwargs and kwargs["training"]:
5058
if (kwargs["type"] == "backward"):
5159
cache_file = cache_file + "_backward"
52-
if "tuner" in kwargs and os.path.exists(cache_file):
60+
if "tuner" in kwargs and check_cache_file_exists(cache_file):
5361
tuner = kwargs["tuner"]
5462
loaded_options = tuner.load(cache_file, name, list(inputs), 1)
5563
if len(loaded_options) > 0:
5664
options = loaded_options[0]
57-
elif os.path.exists(cache_file):
65+
elif check_cache_file_exists(cache_file):
5866
tuner = TcAutotuner(kwargs["tc_lang"])
5967
options = tuner.load(cache_file, name, list(inputs))
6068
return options
@@ -225,7 +233,6 @@ def autotune(self, *inputs, **kwargs):
225233
kwargs.pop("name", None)
226234
backward = True if backward_name is not None else False
227235
hash_key = get_tc_hash_key(name, *input_tensors)
228-
229236
# lookup for the options in the cache. Whenever we make the call to
230237
# autotune, tuning must happen. But if the kernel has been tuned earlier
231238
# then we can use previous options to seed the tuning.
@@ -592,6 +599,10 @@ def autotune(self, *inputs, **kwargs):
592599
kwargs.update(self.kwargs_define)
593600
if self.tuner is None:
594601
self.tuner = TcAutotuner(self.lang, **kwargs)
602+
else:
603+
# we do the init again so that the autotuner parameters are updated
604+
# properly if users change them
605+
self.tuner.__init__(self.lang, **kwargs)
595606
return self.tuner.autotune(*inputs, **kwargs)
596607

597608
###############################################################################

0 commit comments

Comments
 (0)