|
12 | 12 | from lightllm.utils.device_utils import get_current_device_name |
13 | 13 | from lightllm.utils.log_utils import init_logger |
14 | 14 | from typing import Callable, Optional, Union, List |
15 | | -from lightllm.utils.envs_utils import get_triton_autotune_level |
| 15 | +from lightllm.utils.envs_utils import get_env_start_args, get_triton_autotune_level |
16 | 16 | from lightllm.common.kernel_config import KernelConfigs |
17 | 17 | from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node |
18 | 18 |
|
@@ -218,6 +218,35 @@ def _try_load_cache(self, static_key): |
218 | 218 | logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}") |
219 | 219 | with open(cache_file, "rb") as f: |
220 | 220 | self.cached_configs[static_key] = orjson.loads(f.read()) |
| 221 | + elif get_env_start_args().enable_kernel_config_fallback: |
| 222 | + # list the all triton versions dir |
| 223 | + possilble_triton_versions = os.listdir(os.path.join(Path(__file__).parent, "autotune_kernel_configs")) |
| 224 | + # get the current triton version |
| 225 | + current_triton_version = get_triton_version() |
| 226 | + # try sort by the distance between current triton version and possilble triton versions |
| 227 | + possilble_triton_versions = sorted( |
| 228 | + possilble_triton_versions, |
| 229 | + key=lambda x: abs( |
| 230 | + int(x.replace("triton_", "").replace(".", "")) |
| 231 | + - int(current_triton_version.replace("triton_", "").replace(".", "")) |
| 232 | + ), |
| 233 | + ) |
| 234 | + for triton_version in possilble_triton_versions: |
| 235 | + fallback_cache_file = os.path.join( |
| 236 | + Path(__file__).parent, |
| 237 | + "autotune_kernel_configs", |
| 238 | + triton_version, |
| 239 | + get_current_device_name(), |
| 240 | + self.kernel_name, |
| 241 | + KernelConfigs.get_config_file_name(static_key), |
| 242 | + ) |
| 243 | + if os.path.exists(fallback_cache_file): |
| 244 | + logger.warning( |
| 245 | + f"Fallback loading cached configs for {self.kernel_name} - {static_key} " |
| 246 | + f"from triton version {triton_version}" |
| 247 | + ) |
| 248 | + with open(fallback_cache_file, "rb") as f: |
| 249 | + self.cached_configs[static_key] = orjson.loads(f.read()) |
221 | 250 | return True |
222 | 251 |
|
223 | 252 | def kernel_warmup(self, static_key, *args, **kwargs): |
|
0 commit comments