Skip to content

Commit 0872088

Browse files
improve logic
1 parent d2d8256 commit 0872088

File tree

1 file changed

+61
-19
lines changed

1 file changed

+61
-19
lines changed

lightllm/common/triton_utils/autotuner.py

Lines changed: 61 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
import re
12
import triton
23
import orjson
34
import os
@@ -11,7 +12,7 @@
1112
from frozendict import frozendict
1213
from lightllm.utils.device_utils import get_current_device_name
1314
from lightllm.utils.log_utils import init_logger
14-
from typing import Callable, Optional, Union, List
15+
from typing import Callable, Optional, Tuple, Union, List
1516
from lightllm.utils.envs_utils import get_env_start_args, get_triton_autotune_level
1617
from lightllm.common.kernel_config import KernelConfigs
1718
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
@@ -219,19 +220,52 @@ def _try_load_cache(self, static_key):
219220
with open(cache_file, "rb") as f:
220221
self.cached_configs[static_key] = orjson.loads(f.read())
221222
elif get_env_start_args().enable_kernel_config_fallback:
222-
# list the all triton versions dir
223-
possilble_triton_versions = os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs"))
224-
# get the current triton version
223+
224+
def parse_triton_version_tag(tag: str) -> Optional[Tuple[int, int, int]]:
    """
    Parse a Triton version tag into a numeric (major, minor, patch) tuple.

    Accepts "triton_X.Y.Z" or "triton_X.Y"; a missing patch component
    defaults to 0.

    Args:
        tag: Candidate tag string, e.g. "triton_3.4.0".

    Returns:
        (X, Y, Z) as ints, or None when the tag is not exactly of that form.
    """
    # Use fullmatch instead of match + "$": "$" also matches just before a
    # trailing newline, which would let a tag ending in a newline slip through
    # as valid.
    match = re.fullmatch(r"triton_(\d+)\.(\d+)(?:\.(\d+))?", tag)
    if match is None:
        return None
    major, minor, patch = match.groups()
    return (int(major), int(minor), int(patch) if patch is not None else 0)
234+
235+
def version_distance(v1: Tuple[int, int, int], v2: Tuple[int, int, int]) -> int:
    """
    Return the absolute weighted gap between two (major, minor, patch) versions.

    Each component difference is weighted (major by 1_000_000, minor by 1_000,
    patch by 1); the signed weighted differences are summed first and the
    absolute value is taken of the total.
    """
    major_gap = v1[0] - v2[0]
    minor_gap = v1[1] - v2[1]
    patch_gap = v1[2] - v2[2]
    signed_total = major_gap * 1_000_000 + minor_gap * 1_000 + patch_gap
    return abs(signed_total)
241+
225242
current_triton_version = get_triton_version()
226-
# try sort by the distance between current triton version and possilble triton versions
227-
possilble_triton_versions = sorted(
228-
possilble_triton_versions,
229-
key=lambda x: abs(
230-
int(x.replace("triton_", "").replace(".", ""))
231-
- int(current_triton_version.replace("triton_", "").replace(".", ""))
232-
),
233-
)
234-
for triton_version in possilble_triton_versions:
243+
current_parsed = parse_triton_version_tag(current_triton_version)
244+
if current_parsed is None:
245+
logger.error("Unable to parse current Triton version. Triton may not be installed properly.")
246+
possible_dirs = [
247+
d
248+
for d in os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs"))
249+
if d.startswith("triton_")
250+
]
251+
possible_dirs.sort()
252+
else:
253+
config_dir = os.path.join(Path(__file__).parent, "autotune_kernel_configs")
254+
possible_dirs = []
255+
for d in os.listdir(config_dir):
256+
if not d.startswith("triton_"):
257+
continue
258+
parsed = parse_triton_version_tag(d)
259+
if parsed is not None:
260+
dist = version_distance(parsed, current_parsed)
261+
possible_dirs.append((dist, d, parsed))
262+
else:
263+
logger.debug(f"Skipping invalid version directory: {d}")
264+
possible_dirs.sort(key=lambda x: x[0])
265+
possible_dirs = [d for _, d, _ in possible_dirs]
266+
267+
loaded = False
268+
for triton_version in possible_dirs:
235269
fallback_cache_file = os.path.join(
236270
Path(__file__).parent,
237271
"autotune_kernel_configs",
@@ -241,12 +275,20 @@ def _try_load_cache(self, static_key):
241275
KernelConfigs.get_config_file_name(static_key),
242276
)
243277
if os.path.exists(fallback_cache_file):
244-
logger.warning(
245-
f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
246-
f"from triton version {triton_version}"
247-
)
248-
with open(fallback_cache_file, "rb") as f:
249-
self.cached_configs[static_key] = orjson.loads(f.read())
278+
try:
279+
logger.warning(
280+
f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
281+
f"from triton version {triton_version} (current: {current_triton_version})"
282+
)
283+
with open(fallback_cache_file, "rb") as f:
284+
self.cached_configs[static_key] = orjson.loads(f.read())
285+
loaded = True
286+
break
287+
except Exception as e:
288+
logger.error(f"Failed to load fallback config from {fallback_cache_file}: {e}")
289+
290+
if not loaded:
291+
logger.info(f"No fallback config found for {self.kernel_name} - {static_key}")
250292
return True
251293

252294
def kernel_warmup(self, static_key, *args, **kwargs):

0 commit comments

Comments
 (0)