1+ import re
12import triton
23import orjson
34import os
1112from frozendict import frozendict
1213from lightllm .utils .device_utils import get_current_device_name
1314from lightllm .utils .log_utils import init_logger
14- from typing import Callable , Optional , Union , List
15+ from typing import Callable , Optional , Tuple , Union , List
1516from lightllm .utils .envs_utils import get_env_start_args , get_triton_autotune_level
1617from lightllm .common .kernel_config import KernelConfigs
1718from lightllm .utils .dist_utils import get_global_world_size , get_global_rank , get_current_rank_in_node
@@ -219,19 +220,52 @@ def _try_load_cache(self, static_key):
219220 with open (cache_file , "rb" ) as f :
220221 self .cached_configs [static_key ] = orjson .loads (f .read ())
221222 elif get_env_start_args ().enable_kernel_config_fallback :
222- # list the all triton versions dir
223- possilble_triton_versions = os .listdir (os .path .join (Path (__file__ ).parent , "autotune_kernel_configs" ))
224- # get the current triton version
223+
224+ def parse_triton_version_tag (tag : str ) -> Optional [Tuple [int , int , int ]]:
225+ """
226+ Parse "triton_X.Y.Z" or "triton_X.Y" to (X, Y, Z), Z defaults to 0.
227+ Returns None if invalid.
228+ """
229+ match = re .match (r"^triton_(\d+)\.(\d+)(?:\.(\d+))?$" , tag )
230+ if not match :
231+ return None
232+ x , y , z = match .groups ()
233+ return (int (x ), int (y ), int (z ) if z is not None else 0 )
234+
235+ def version_distance (v1 : Tuple [int , int , int ], v2 : Tuple [int , int , int ]) -> int :
236+ """
237+ Compute weighted distance: major * 1e6 + minor * 1e3 + patch
238+ Ensures lexicographic ordering.
239+ """
240+ return abs ((v1 [0 ] - v2 [0 ]) * 1_000_000 + (v1 [1 ] - v2 [1 ]) * 1_000 + (v1 [2 ] - v2 [2 ]))
241+
225242 current_triton_version = get_triton_version ()
226- # try sort by the distance between current triton version and possilble triton versions
227- possilble_triton_versions = sorted (
228- possilble_triton_versions ,
229- key = lambda x : abs (
230- int (x .replace ("triton_" , "" ).replace ("." , "" ))
231- - int (current_triton_version .replace ("triton_" , "" ).replace ("." , "" ))
232- ),
233- )
234- for triton_version in possilble_triton_versions :
243+ current_parsed = parse_triton_version_tag (current_triton_version )
244+ if current_parsed is None :
245+ logger .error ("Unable to parse current Triton version. Triton may not be installed properly." )
246+ possible_dirs = [
247+ d
248+ for d in os .listdir (os .path .join (Path (__file__ ).parent , "autotune_kernel_configs" ))
249+ if d .startswith ("triton_" )
250+ ]
251+ possible_dirs .sort ()
252+ else :
253+ config_dir = os .path .join (Path (__file__ ).parent , "autotune_kernel_configs" )
254+ possible_dirs = []
255+ for d in os .listdir (config_dir ):
256+ if not d .startswith ("triton_" ):
257+ continue
258+ parsed = parse_triton_version_tag (d )
259+ if parsed is not None :
260+ dist = version_distance (parsed , current_parsed )
261+ possible_dirs .append ((dist , d , parsed ))
262+ else :
263+ logger .debug (f"Skipping invalid version directory: { d } " )
264+ possible_dirs .sort (key = lambda x : x [0 ])
265+ possible_dirs = [d for _ , d , _ in possible_dirs ]
266+
267+ loaded = False
268+ for triton_version in possible_dirs :
235269 fallback_cache_file = os .path .join (
236270 Path (__file__ ).parent ,
237271 "autotune_kernel_configs" ,
@@ -241,12 +275,20 @@ def _try_load_cache(self, static_key):
241275 KernelConfigs .get_config_file_name (static_key ),
242276 )
243277 if os .path .exists (fallback_cache_file ):
244- logger .warning (
245- f"Fallback loading cached configs for { self .kernel_name } - { static_key } "
246- f"from triton version { triton_version } "
247- )
248- with open (fallback_cache_file , "rb" ) as f :
249- self .cached_configs [static_key ] = orjson .loads (f .read ())
278+ try :
279+ logger .warning (
280+ f"Fallback loading cached configs for { self .kernel_name } - { static_key } "
281+ f"from triton version { triton_version } (current: { current_triton_version } )"
282+ )
283+ with open (fallback_cache_file , "rb" ) as f :
284+ self .cached_configs [static_key ] = orjson .loads (f .read ())
285+ loaded = True
286+ break
287+ except Exception as e :
288+ logger .error (f"Failed to load fallback config from { fallback_cache_file } : { e } " )
289+
290+ if not loaded :
291+ logger .info (f"No fallback config found for { self .kernel_name } - { static_key } " )
250292 return True
251293
252294 def kernel_warmup (self , static_key , * args , ** kwargs ):
0 commit comments