
Commit d6695ef

add AutotuneLevel for more detailed autotune
1 parent 237ae00

File tree

8 files changed: +88 additions, -40 deletions

lightllm/common/basemodel/basemodel.py
lightllm/common/fused_moe/grouped_fused_moe.py
lightllm/common/fused_moe/moe_silu_and_mul.py
lightllm/common/fused_moe/moe_sum_reduce.py
lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py
lightllm/common/triton_utils/autotuner.py
lightllm/models/deepseek2/triton_kernel/rotary_emb.py
lightllm/utils/envs_utils.py

lightllm/common/basemodel/basemodel.py

Lines changed: 5 additions & 3 deletions
@@ -24,8 +24,9 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.common.triton_utils.autotuner import AutotuneLevel
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
-from lightllm.utils.envs_utils import set_model_init_status, is_triton_autotune_enabled, disable_triton_autotune
+from lightllm.utils.envs_utils import set_model_init_status, set_triton_autotune_level, get_triton_autotune_level
 from lightllm.utils.infer_utils import post_empty_cache

 logger = init_logger(__name__)
@@ -731,7 +732,7 @@ def autotune_layers(self):
     @torch.no_grad()
     @post_empty_cache
     def _autotune_warmup(self):
-        if not is_triton_autotune_enabled():
+        if get_triton_autotune_level() in [AutotuneLevel.NO_AUTOTUNE, AutotuneLevel.CLOSE_AUTOTUNE]:
             return

         torch.distributed.barrier()
@@ -794,7 +795,8 @@ def _autotune_warmup(self):
         torch.cuda.empty_cache()
         self.layers_num = layer_num_bak
         torch.distributed.barrier()
-        disable_triton_autotune()
+        if get_triton_autotune_level() not in [AutotuneLevel.AUTOTUNE_RUNTIME, AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE]:
+            set_triton_autotune_level(AutotuneLevel.NO_AUTOTUNE)

     @final
     @torch.no_grad()

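Note on the warmup gating above: warmup-time tuning is skipped for NO_AUTOTUNE and CLOSE_AUTOTUNE, and once warmup finishes every non-runtime level is dropped to NO_AUTOTUNE so serving traffic only reads cached configs, while the two runtime levels stay live. A minimal sketch of that decision logic (the helper functions here are hypothetical, added only to restate the two checks; they are not part of the commit):

from lightllm.common.triton_utils.autotuner import AutotuneLevel


def should_run_warmup(level: int) -> bool:
    # Mirrors the early return added in _autotune_warmup.
    return level not in (AutotuneLevel.NO_AUTOTUNE, AutotuneLevel.CLOSE_AUTOTUNE)


def level_after_warmup(level: int) -> int:
    # Mirrors the tail of _autotune_warmup: the runtime levels keep tuning enabled,
    # every other level falls back to NO_AUTOTUNE (cache-only) after warmup.
    if level in (AutotuneLevel.AUTOTUNE_RUNTIME, AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE):
        return level
    return AutotuneLevel.NO_AUTOTUNE
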
lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 2 additions & 2 deletions
@@ -35,7 +35,7 @@
 from .moe_sum_reduce import moe_sum_reduce
 from lightllm.common.quantization.triton_quant.fp8.fp8act_quant_kernel import per_token_group_quant_fp8
 from lightllm.utils.torch_ops_utils import direct_register_custom_op
-from lightllm.common.triton_utils.autotuner import autotune
+from lightllm.common.triton_utils.autotuner import autotune, closest_pow_of_2

 FFN_MOE_CHUNK_SIZE = 32 * 1024

@@ -492,7 +492,7 @@ def _get_grouped_matmul_configs():
     kernel_name="grouped_matmul:v1",
     configs_gen_func=_get_grouped_matmul_configs,
     static_key_func=_get_grouped_matmul_static_key,
-    run_key_func=lambda token_inputs: token_inputs.shape[0],
+    run_key_func=lambda token_inputs: closest_pow_of_2(token_inputs.shape[0]),
     mutates_args=["out"],
 )
 def grouped_matmul(

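Note on the run_key change above (the kernels below receive the same change): instead of keying tuned configs on the exact token count, the run key is bucketed with closest_pow_of_2, which this commit adds in lightllm/common/triton_utils/autotuner.py, so nearby batch sizes share one cached entry rather than each token count producing its own lookup and tuning pass. A quick illustrative check of the bucketing, restating the formula from that file:

import triton


def closest_pow_of_2(x):
    # Same formula the commit adds to the autotuner: round x to the nearest
    # power of two, with the cutover at three quarters of the next power.
    return triton.next_power_of_2(x - triton.next_power_of_2(x) // 4)


# Expected buckets: 20 -> 16, 48 -> 32, 49 -> 64, 100 -> 128
for n in (20, 48, 49, 100):
    print(n, "->", closest_pow_of_2(n))
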
lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import triton
 import triton.language as tl
 from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
-from lightllm.common.triton_utils.autotuner import autotune
+from lightllm.common.triton_utils.autotuner import autotune, closest_pow_of_2


 @triton.jit
@@ -81,7 +81,7 @@ def _get_silu_and_mul_static_key(input: torch.Tensor, output: torch.Tensor):
     kernel_name="silu_and_mul_fwd:v1",
     configs_gen_func=_get_silu_and_mul_configs,
     static_key_func=_get_silu_and_mul_static_key,
-    run_key_func=lambda input: input.shape[0],
+    run_key_func=lambda input: closest_pow_of_2(input.shape[0]),
     mutates_args=["output"],
 )
 def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, run_config=None):

lightllm/common/fused_moe/moe_sum_reduce.py

Lines changed: 2 additions & 2 deletions
@@ -4,7 +4,7 @@
 import triton.language as tl
 from .moe_sum_recude_config import MoeSumReduceKernelConfig
 from typing import Any, Callable, Dict, Optional, Tuple
-from lightllm.common.triton_utils.autotuner import autotune
+from lightllm.common.triton_utils.autotuner import autotune, closest_pow_of_2


 @triton.jit
@@ -66,7 +66,7 @@ def _get_moe_sum_reduce_configs():
     kernel_name="moe_sum_reduce:v1",
     configs_gen_func=_get_moe_sum_reduce_configs,
     static_key_func=_get_moe_sum_reduce_static_key,
-    run_key_func=lambda input: input.shape[0],
+    run_key_func=lambda input: closest_pow_of_2(input.shape[0]),
     mutates_args=["output"],
 )
 def moe_sum_reduce(input: torch.Tensor, output: torch.Tensor, run_config: Dict = None):

lightllm/common/quantization/triton_quant/fp8/fp8w8a8_block_gemm_kernel.py

Lines changed: 2 additions & 2 deletions
@@ -7,7 +7,7 @@
 from functools import lru_cache
 from typing import Any, Dict, List, Optional, Tuple
 from triton import Config
-from lightllm.common.triton_utils.autotuner import autotune
+from lightllm.common.triton_utils.autotuner import autotune, closest_pow_of_2


 class Fp8BlockMMKernelConfig(KernelConfigs):
@@ -180,7 +180,7 @@ def _get_static_key(A, B, block_size, dtype):
     kernel_name="w8a8_block_fp8_matmul:v1",
     configs_gen_func=get_test_configs,
     static_key_func=_get_static_key,
-    run_key_func=lambda A: A.shape[0],
+    run_key_func=lambda A: closest_pow_of_2(A.shape[0]),
     mutates_args=["C"],
 )
 def w8a8_block_fp8_matmul(

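Note on the bucketed run keys used above: when a bucketed key has no exact cached entry, the Autotuner (next file) falls back to run_key_distance_func and picks the nearest cached key. A small illustrative sketch of that fallback with made-up cache contents (the keys and config values here are hypothetical):

# Hypothetical cached entries for one static key; mirrors the
# min(..., key=run_key_distance_func) fallback in Autotuner.__call__.
cached = {"16": {"BLOCK_M": 64}, "128": {"BLOCK_M": 128}}
run_key = "48"  # bucketed token count with no exact match


def distance(run_key, config_key):
    # The decorator's default run_key_distance_func.
    return abs(int(run_key) - int(config_key))


closest_key, closest_config = min(cached.items(), key=lambda item: distance(run_key, item[0]))
print(closest_key, closest_config)  # prints: 16 {'BLOCK_M': 64} -- the nearest cached bucket wins
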
lightllm/common/triton_utils/autotuner.py

Lines changed: 68 additions & 20 deletions
@@ -12,14 +12,33 @@
 from lightllm.utils.device_utils import get_current_device_name
 from lightllm.utils.log_utils import init_logger
 from typing import Callable, Optional, Union, List
-from lightllm.utils.envs_utils import is_triton_autotune_enabled
+from lightllm.utils.envs_utils import get_triton_autotune_level
 from lightllm.common.kernel_config import KernelConfigs
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_rank_in_node
-from lightllm.distributed.communication_op import dist_group_manager

 logger = init_logger(__name__)


+def _get_autotune_group():
+    from lightllm.distributed.communication_op import dist_group_manager
+    return dist_group_manager.get_default_group().autotune_group
+
+
+class AutotuneLevel:
+    # Do not autotune; only use configs from cached files.
+    NO_AUTOTUNE = 0
+    # Autotune if no config is cached.
+    AUTOTUNE = 1
+    # Always autotune and overwrite the cached config files.
+    AUTOTUNE_OVERWRITE = 2
+    # Autotune at runtime to keep searching for better configs.
+    AUTOTUNE_RUNTIME = 3
+    # Autotune at runtime to keep searching for better configs and overwrite the cached config files.
+    AUTOTUNE_RUNTIME_OVERWRITE = 4
+    # Disable autotune and do not use cached config files.
+    CLOSE_AUTOTUNE = 5
+
+
 def autotune(
     kernel_name: str,
     configs_gen_func: Callable[[], List],
@@ -28,6 +47,29 @@ def autotune(
     run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
     mutates_args: List[str] = [],
 ):
+    """Decorator that constructs and returns an Autotuner wrapper for a Triton kernel.
+
+    This decorator configures an Autotuner with the provided configuration
+    generator and key functions, enabling on-demand benchmarking and caching
+    of kernel run configurations across runs and processes.
+
+    Args:
+        kernel_name (str): Human-readable kernel name used for logging and cache paths.
+        configs_gen_func (Callable[[], List]): Function that returns candidate run configurations.
+        static_key_func (Callable): Function that derives a static key (dict-like) from call arguments.
+            This key identifies the cache file that stores tuned configs.
+        run_key_func (Callable): Function that derives a run-time key from call arguments.
+            This key indexes tuned configs within a static key's cache.
+        run_key_distance_func (Callable, optional): Distance metric taking ``(run_key, config_key)`` and
+            returning a comparable value; used to pick the closest config when an exact match is absent.
+            Defaults to ``abs(int(run_key) - int(config_key))``.
+        mutates_args (List[str], optional): Names of arguments that can be mutated by the kernel.
+            During benchmarking, defensive clones are made to avoid side effects. Defaults to ``[]``.
+
+    Returns:
+        Callable: A callable object that wraps the original function and performs autotuning
+            as needed before invocation.
+    """
     def decorator(fn):
         return Autotuner(
             fn=fn,
@@ -53,8 +95,7 @@ def __init__(
         run_key_distance_func: Callable = lambda run_key, config_key: abs(int(run_key) - int(config_key)),
         mutates_args: List[str] = [],
     ):
-        # Whether to use this autotune decorator
-        self.disable_autotune = not is_triton_autotune_enabled()
+        self.autotune_level = get_triton_autotune_level()

         self.configs_gen_func = configs_gen_func
         self.kernel_name = kernel_name
@@ -65,7 +106,6 @@ def __init__(
             get_current_device_name(),
             self.kernel_name,
         )
-        os.makedirs(self.cache_dir, exist_ok=True)
         self.fn = fn
         self.static_key_func = static_key_func
         self.run_key_func = run_key_func
@@ -81,38 +121,42 @@ def __init__(
         ]
         self._run_key_func_param_names = [name for name, _ in inspect.signature(self.run_key_func).parameters.items()]
         self.mutates_args = mutates_args
+
+        assert self.autotune_level in [AutotuneLevel.NO_AUTOTUNE, AutotuneLevel.AUTOTUNE, AutotuneLevel.AUTOTUNE_OVERWRITE, AutotuneLevel.AUTOTUNE_RUNTIME, AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE, AutotuneLevel.CLOSE_AUTOTUNE]
         return

     @torch.no_grad()
     def __call__(self, *args, **kwargs):
         if kwargs.get("run_config", None) is not None:
             return self.fn(*args, **kwargs)

-        if self.disable_autotune:
+        # If autotune_level is AutotuneLevel.CLOSE_AUTOTUNE, skip autotuning entirely.
+        if self.autotune_level == AutotuneLevel.CLOSE_AUTOTUNE:
             return self.fn(*args, **kwargs)

         rank_id = 0 if not dist.is_initialized() else get_global_rank()
         world_size = 1 if not dist.is_initialized() else get_global_world_size()

-        static_key = self._static_key(*args, **kwargs)
+        static_key = frozendict(self._static_key(*args, **kwargs))
         run_key = str(self._run_key(*args, **kwargs))

-        # Lazy load
-        self._try_load_cache(static_key)
+        # Lazily load the cached configs from lightllm/common/triton_utils/autotune_kernel_configs.
+        if self.autotune_level not in [AutotuneLevel.AUTOTUNE_OVERWRITE, AutotuneLevel.AUTOTUNE_RUNTIME_OVERWRITE]:
+            self._try_load_cache(static_key)

-        if static_key not in self.cached_configs:
+        if static_key not in self.cached_configs and self.autotune_level == AutotuneLevel.NO_AUTOTUNE:
             if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
                 logger.warning(
                     f"No kernel config for {self.kernel_name} in {KernelConfigs.get_config_file_name(static_key)}",
                 )
             self.cached_configs[static_key] = {}

-        if is_triton_autotune_enabled():
+        if self.autotune_level != AutotuneLevel.NO_AUTOTUNE:
             need_tunning = run_key not in self.cached_configs.get(static_key, {})
             if world_size > 1:
                 _need_tunnings = [None for _ in range(world_size)]
                 dist.all_gather_object(
-                    _need_tunnings, obj=need_tunning, group=dist_group_manager.get_default_group().autotune_group
+                    _need_tunnings, obj=need_tunning, group=_get_autotune_group()
                 )
                 need_tunning = any(_need_tunnings)
             if need_tunning:
@@ -125,12 +169,12 @@ def __call__(self, *args, **kwargs):
                     world_size=world_size,
                 )

-        if static_key in self.fast_match_configs and run_key in self.fast_match_configs[static_key]:
-            closest_config = self.fast_match_configs[static_key][run_key]
-            kwargs["run_config"] = closest_config
+        fast_for_key = self.fast_match_configs.get(static_key)
+        if fast_for_key is not None and run_key in fast_for_key:
+            kwargs["run_config"] = fast_for_key[run_key]
             return self.fn(*args, **kwargs)

-        all_configs = self.cached_configs.get(static_key)
+        all_configs = self.cached_configs.get(static_key, {})
         if len(all_configs) != 0:
             closest_config = min(
                 list(all_configs.items()), key=lambda item: self.run_key_distance_func(run_key, item[0])
@@ -146,6 +190,7 @@ def _try_load_cache(self, static_key):

         cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
         if os.path.exists(cache_file):
+            logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
             with open(cache_file, "rb") as f:
                 self.cached_configs[static_key] = orjson.loads(f.read())
             return
@@ -195,7 +240,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
         all_keys = [None for _ in range(world_size)]
         all_key_str = f"{run_key}_{static_key}"
         dist.all_gather_object(
-            all_keys, obj=all_key_str, group=dist_group_manager.get_default_group().autotune_group
+            all_keys, obj=all_key_str, group=_get_autotune_group()
         )
         is_key_all_same = all(all_keys[0] == k for k in all_keys)
         if not is_key_all_same:
@@ -237,7 +282,7 @@ def _autotune(self, args, kwargs, static_key, run_key, rank_id, world_size):
             dist.all_gather_object(
                 all_gather_configs,
                 obj=(best_time, run_key, dict(static_key), best_config),
-                group=dist_group_manager.get_default_group().autotune_group,
+                group=_get_autotune_group(),
             )
             all_gather_configs = sorted(all_gather_configs, key=lambda x: x[0])
             key_set = set()
@@ -318,8 +363,7 @@ def _select_args(self, param_names, args, kwargs):

     def _static_key(self, *args, **kwargs):
         params = self._select_args(self._static_key_func_param_names, args, kwargs)
-        key = self.static_key_func(*params)
-        return frozendict(key)
+        return self.static_key_func(*params)

     def _run_key(self, *args, **kwargs):
         params = self._select_args(self._run_key_func_param_names, args, kwargs)
@@ -347,3 +391,7 @@ def get_triton_version():
 def split_configs(configs, global_rank, global_world_size):
     random.Random(0).shuffle(configs)
     return configs[global_rank::global_world_size]
+
+
+def closest_pow_of_2(x):
+    return triton.next_power_of_2(x - triton.next_power_of_2(x) // 4)

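For orientation, here is how a kernel opts into the autotuner after this change, modelled on the call sites updated above; the kernel, config generator, and key functions below are hypothetical and not part of the commit:

from lightllm.common.triton_utils.autotuner import autotune, closest_pow_of_2


def _my_kernel_configs():
    # Candidate run configurations the autotuner will benchmark.
    return [{"BLOCK": 64, "num_warps": 4}, {"BLOCK": 128, "num_warps": 8}]


def _my_kernel_static_key(x):
    # Dict-like static key; it selects the cache file that stores tuned configs.
    return {"dtype": str(x.dtype)}


@autotune(
    kernel_name="my_kernel:v1",
    configs_gen_func=_my_kernel_configs,
    static_key_func=_my_kernel_static_key,
    run_key_func=lambda x: closest_pow_of_2(x.shape[0]),  # bucketed run key
    mutates_args=["out"],
)
def my_kernel(x, out, run_config=None):
    # The selected config arrives via run_config; passing run_config explicitly
    # bypasses the autotuner, exactly as in Autotuner.__call__ above.
    ...
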
lightllm/models/deepseek2/triton_kernel/rotary_emb.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import triton
 import triton.language as tl
 import itertools
-from lightllm.common.triton_utils.autotuner import autotune
+from lightllm.common.triton_utils.autotuner import autotune, closest_pow_of_2


 @triton.jit
@@ -122,7 +122,7 @@ def get_static_key(q, k):
     kernel_name="rotary_emb_fwd:v1",
     configs_gen_func=get_test_configs,
    static_key_func=get_static_key,
-    run_key_func=lambda q: q.shape[0],
+    run_key_func=lambda q: closest_pow_of_2(q.shape[0]),
     mutates_args=["q", "k"],
 )
 @torch.no_grad()

lightllm/utils/envs_utils.py

Lines changed: 5 additions & 7 deletions
@@ -149,15 +149,13 @@ def get_kv_quant_calibration_inference_count():
     return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000))


-def is_triton_autotune_enabled():
-    # Whether Triton autotune is enabled (read-only check)
-    mark = os.getenv("LIGHTLLM_TRITON_AUTOTUNE", "False").upper() in ["ON", "TRUE", "1"]
-    return mark
+def get_triton_autotune_level():
+    return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))


-def disable_triton_autotune():
-    # Disable Triton autotune (setter)
-    os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "False"
+def set_triton_autotune_level(level: int):
+    os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = str(level)
+    return


 g_model_init_done = False

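The autotune level is selected through the LIGHTLLM_TRITON_AUTOTUNE_LEVEL environment variable read above (default 0, i.e. NO_AUTOTUNE). Typically it would be exported before launching the server; the illustrative sketch below only shows how the variable maps onto the new helpers:

import os

from lightllm.common.triton_utils.autotuner import AutotuneLevel
from lightllm.utils.envs_utils import get_triton_autotune_level, set_triton_autotune_level

# Tune at warmup whenever a config is missing; runtime levels (3 and 4) would
# additionally keep tuning newly seen shapes while serving.
os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = str(AutotuneLevel.AUTOTUNE)
assert get_triton_autotune_level() == AutotuneLevel.AUTOTUNE

# After warmup, non-runtime levels are switched back to cache-only:
set_triton_autotune_level(AutotuneLevel.NO_AUTOTUNE)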