Skip to content

Commit d2d8256

Browse files
add fallback logic of triton autotune
1 parent aff4049 commit d2d8256

File tree

3 files changed

+36
-1
lines changed

3 files changed

+36
-1
lines changed

lightllm/common/triton_utils/autotuner.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from lightllm.utils.device_utils import get_current_device_name
1313
from lightllm.utils.log_utils import init_logger
1414
from typing import Callable, Optional, Union, List
15-
from lightllm.utils.envs_utils import get_triton_autotune_level
15+
from lightllm.utils.envs_utils import get_env_start_args, get_triton_autotune_level
1616
from lightllm.common.kernel_config import KernelConfigs
1717
from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
1818

@@ -218,6 +218,35 @@ def _try_load_cache(self, static_key):
218218
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
219219
with open(cache_file, "rb") as f:
220220
self.cached_configs[static_key] = orjson.loads(f.read())
221+
elif get_env_start_args().enable_kernel_config_fallback:
222+
# list all the triton version directories
223+
possilble_triton_versions = os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs"))
224+
# get the current triton version
225+
current_triton_version = get_triton_version()
226+
# sort the possible triton versions by their distance from the current triton version
227+
possilble_triton_versions = sorted(
228+
possilble_triton_versions,
229+
key=lambda x: abs(
230+
int(x.replace("triton_", "").replace(".", ""))
231+
- int(current_triton_version.replace("triton_", "").replace(".", ""))
232+
),
233+
)
234+
for triton_version in possilble_triton_versions:
235+
fallback_cache_file = os.path.join(
236+
Path(__file__).parent,
237+
"autotune_kernel_configs",
238+
triton_version,
239+
get_current_device_name(),
240+
self.kernel_name,
241+
KernelConfigs.get_config_file_name(static_key),
242+
)
243+
if os.path.exists(fallback_cache_file):
244+
logger.warning(
245+
f"Fallback loading cached configs for {self.kernel_name} - {static_key} "
246+
f"from triton version {triton_version}"
247+
)
248+
with open(fallback_cache_file, "rb") as f:
249+
self.cached_configs[static_key] = orjson.loads(f.read())
221250
return True
222251

223252
def kernel_warmup(self, static_key, *args, **kwargs):

lightllm/server/api_cli.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,11 @@ def make_argument_parser() -> argparse.ArgumentParser:
332332
action="store_true",
333333
help="""inference backend will use the fa3 attention kernel for prefill and decode""",
334334
)
335+
parser.add_argument(
336+
"--enable_kernel_config_fallback",
337+
action="store_true",
338+
help="""Whether to enable kernel config fallback when triton version is not compatible.""",
339+
)
335340
parser.add_argument(
336341
"--cache_capacity", type=int, default=200, help="cache server capacity for multimodal resources"
337342
)

lightllm/server/core/objs/start_args_type.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,4 @@ class StartArgs:
131131

132132
# kernel setting
133133
enable_fa3: bool = field(default=False)
134+
enable_kernel_config_fallback: bool = field(default=False)

0 commit comments

Comments
 (0)