@@ -88,9 +88,6 @@ def __call__(self, *args, **kwargs):
8888 if kwargs .get ("run_config" , None ) is not None :
8989 return self .fn (* args , ** kwargs )
9090
91- if self .disable_autotune :
92- return self .fn (* args , ** kwargs )
93-
9491 rank_id = 0 if not dist .is_initialized () else get_global_rank ()
9592 world_size = 1 if not dist .is_initialized () else get_global_world_size ()
9693
@@ -146,11 +143,12 @@ def _try_load_cache(self, static_key):
146143
147144 cache_file = os .path .join (self .cache_dir , KernelConfigs .get_config_file_name (static_key ))
148145 if os .path .exists (cache_file ):
146+ logger .info (f"Loading cached configs for { self .kernel_name } - { static_key } " )
149147 with open (cache_file , "rb" ) as f :
150148 self .cached_configs [static_key ] = orjson .loads (f .read ())
151149 return
152150
153- def _bench (self , * args , n_repeat = 3 , n_retries = 1 , ** kwargs ):
151+ def _bench (self , * args , n_repeat = 3 , n_retries = 5 , ** kwargs ):
154152 from triton .compiler .errors import CompileTimeAssertionFailure
155153 from triton .runtime .errors import OutOfResources , PTXASError
156154
0 commit comments