Skip to content

Commit 6475617

Browse files
sufubaohiworldwzj
andauthored
force to warmup triton autotune configs in start. (#1043)
Co-authored-by: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
1 parent 1233b02 commit 6475617

File tree

6 files changed

+58
-25
lines changed

6 files changed

+58
-25
lines changed

lightllm/common/basemodel/basemodel.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import torch
88
import torch.nn.functional as F
99
from typing import final
10+
from tqdm import tqdm
1011

1112
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
1213
from lightllm.common.basemodel.infer_struct import InferStateInfo
@@ -26,7 +27,8 @@
2627
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
2728
from lightllm.common.triton_utils.autotuner import AutotuneLevel
2829
from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
29-
from lightllm.utils.envs_utils import set_model_init_status, set_triton_autotune_level, get_triton_autotune_level
30+
from lightllm.utils.envs_utils import set_model_init_status
31+
from lightllm.common.triton_utils.autotuner import Autotuner
3032
from lightllm.utils.infer_utils import post_empty_cache
3133

3234
logger = init_logger(__name__)
@@ -732,9 +734,7 @@ def autotune_layers(self):
732734
@torch.no_grad()
733735
@post_empty_cache
734736
def _autotune_warmup(self):
735-
if get_triton_autotune_level() not in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
736-
return
737-
737+
Autotuner.start_autotune_warmup()
738738
torch.distributed.barrier()
739739

740740
warmup_lengths = [1, 8, 16, 32, 64, 100, 128, 256, 1024, 2048, 4096]
@@ -748,9 +748,8 @@ def _autotune_warmup(self):
748748

749749
layer_num_bak = self.layers_num
750750
self.layers_num = self.autotune_layers()
751-
for input_len in warmup_lengths:
751+
for input_len in tqdm(warmup_lengths, desc="warming up"):
752752
try:
753-
logger.info(f"autotune warmup for length {input_len}")
754753
rand_gen = torch.Generator(device="cuda")
755754
rand_gen.manual_seed(input_len)
756755
dummy_input_ids = torch.randint(
@@ -785,7 +784,6 @@ def _autotune_warmup(self):
785784
self.mem_manager.free_all()
786785
gc.collect()
787786
torch.cuda.empty_cache()
788-
logger.info(f"autotune warmup for length {input_len} ok")
789787
except Exception as e:
790788
logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}")
791789
logger.exception(str(e))
@@ -795,7 +793,7 @@ def _autotune_warmup(self):
795793
torch.cuda.empty_cache()
796794
self.layers_num = layer_num_bak
797795
torch.distributed.barrier()
798-
set_triton_autotune_level(AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG)
796+
Autotuner.end_autotune_warmup()
799797

800798
@final
801799
@torch.no_grad()

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@
2121
)
2222
from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
2323
from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
24-
from lightllm.utils.envs_utils import get_triton_autotune_level
25-
from lightllm.common.triton_utils.autotuner import AutotuneLevel
2624
from lightllm.utils.log_utils import init_logger
25+
from lightllm.common.triton_utils.autotuner import Autotuner
26+
2727

2828
logger = init_logger(__name__)
2929

@@ -356,7 +356,7 @@ def prefilled_group_gemm(
356356
######################################## warning ##################################################
357357
# here is used to match autotune feature, make moe model run same triton kernel in different rank.
358358
# in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
359-
if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
359+
if Autotuner.is_autotune_warmup():
360360
_gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
361361
_silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
362362
silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,7 @@
1414
)
1515
from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
1616
from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
17-
from lightllm.utils.envs_utils import get_triton_autotune_level
18-
from lightllm.common.triton_utils.autotuner import AutotuneLevel
17+
from lightllm.common.triton_utils.autotuner import Autotuner
1918
import numpy as np
2019

2120
logger = init_logger(__name__)
@@ -190,7 +189,7 @@ def fused_experts_impl(
190189
######################################## warning ##################################################
191190
# here is used to match autotune feature, make moe model run same triton kernel in different rank.
192191
# in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
193-
if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
192+
if Autotuner.is_autotune_warmup():
194193
_gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
195194
_silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
196195
silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/fused_moe/topk_select.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,7 @@
2323
from lightllm.utils.light_utils import light_ops
2424
from typing import Callable, List, Optional, Tuple
2525
from lightllm.common.fused_moe.softmax_topk import softmax_topk
26-
from lightllm.common.triton_utils.autotuner import AutotuneLevel
27-
from lightllm.utils.envs_utils import get_triton_autotune_level
26+
from lightllm.common.triton_utils.autotuner import Autotuner
2827

2928
use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]
3029

@@ -225,7 +224,7 @@ def select_experts(
225224

226225
######################################## warning ##################################################
227226
# here is used to match autotune feature, make topk_ids more random
228-
if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
227+
if Autotuner.is_autotune_warmup():
229228
rand_gen = torch.Generator(device="cuda")
230229
rand_gen.manual_seed(router_logits.shape[0])
231230
router_logits = torch.randn(size=router_logits.shape, generator=rand_gen, dtype=torch.float32, device="cuda")

lightllm/common/triton_utils/autotuner.py

Lines changed: 44 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,22 @@ def decorator(fn):
7777

7878

7979
class Autotuner:
80+
_autotune_warmup: bool = False
81+
82+
@staticmethod
83+
def start_autotune_warmup():
84+
Autotuner._autotune_warmup = True
85+
return
86+
87+
@staticmethod
88+
def end_autotune_warmup():
89+
Autotuner._autotune_warmup = False
90+
return
91+
92+
@staticmethod
93+
def is_autotune_warmup():
94+
return Autotuner._autotune_warmup
95+
8096
def __init__(
8197
self,
8298
fn,
@@ -104,6 +120,7 @@ def __init__(
104120
self.run_key_distance_func = run_key_distance_func
105121
self.cached_configs = {}
106122
self.fast_match_configs = collections.defaultdict(dict)
123+
self.warmuped_configs_set = set()
107124
self.arg_names = [param.name for param in inspect.signature(self.fn).parameters.values()]
108125
self._argname_to_pos = {name: idx for idx, name in enumerate(self.arg_names)}
109126
self._pos_to_argname = {idx: name for idx, name in enumerate(self.arg_names)}
@@ -139,7 +156,13 @@ def __call__(self, *args, **kwargs):
139156
run_key = str(self._run_key(*args, **kwargs))
140157

141158
# Lazy load the cached configs in lightllm/common/triton_utils/autotune_kernel_configs
142-
self._try_load_cache(static_key)
159+
if self._try_load_cache(static_key) or Autotuner.is_autotune_warmup():
160+
all_configs = self.cached_configs.get(static_key, {})
161+
for run_config in all_configs.values():
162+
# warmup all configs
163+
_copy_kwargs = kwargs.copy()
164+
_copy_kwargs["run_config"] = run_config
165+
self.kernel_warmup(static_key, *args, **_copy_kwargs)
143166

144167
if static_key not in self.cached_configs and autotune_level == AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG:
145168
if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
@@ -150,7 +173,10 @@ def __call__(self, *args, **kwargs):
150173
)
151174
self.cached_configs[static_key] = {}
152175

153-
if autotune_level in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
176+
if (
177+
autotune_level in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]
178+
and Autotuner.is_autotune_warmup()
179+
):
154180
need_tuning = (autotune_level == AutotuneLevel.FORCE_AUTOTUNE) or (
155181
run_key not in self.cached_configs.get(static_key, {})
156182
)
@@ -185,13 +211,28 @@ def __call__(self, *args, **kwargs):
185211

186212
def _try_load_cache(self, static_key):
187213
if static_key in self.cached_configs:
188-
return
214+
return False
189215

190216
cache_file = os.path.join(self.cache_dir, KernelConfigs.get_config_file_name(static_key))
191217
if os.path.exists(cache_file):
192218
logger.info(f"Loading cached configs for {self.kernel_name} - {static_key}")
193219
with open(cache_file, "rb") as f:
194220
self.cached_configs[static_key] = orjson.loads(f.read())
221+
return True
222+
223+
def kernel_warmup(self, static_key, *args, **kwargs):
224+
new_args, new_kwargs, origin_list, new_list = self._mutate_args_clone(args, kwargs)
225+
run_config = kwargs.get("run_config", {})
226+
hash_key = str(frozendict(run_config)) + str(static_key)
227+
if hash_key in self.warmuped_configs_set:
228+
return
229+
try:
230+
self.fn(*new_args, **new_kwargs)
231+
self.warmuped_configs_set.add(hash_key)
232+
except:
233+
pass
234+
finally:
235+
self._recover_mutated_args(origin_list=origin_list, new_list=new_list)
195236
return
196237

197238
def _bench(self, *args, n_repeat=3, n_retries=3, **kwargs):

lightllm/utils/envs_utils.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -149,15 +149,11 @@ def get_kv_quant_calibration_inference_count():
149149
return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000))
150150

151151

152+
@lru_cache(maxsize=None)
152153
def get_triton_autotune_level():
153154
return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))
154155

155156

156-
def set_triton_autotune_level(level: int):
157-
os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = str(level)
158-
return
159-
160-
161157
g_model_init_done = False
162158

163159

0 commit comments

Comments
 (0)