
Commit f7f8a54
clean code
1 parent e4ce003

4 files changed: 37 additions, 26 deletions

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 4 deletions
@@ -27,7 +27,8 @@
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
 from lightllm.common.triton_utils.autotuner import AutotuneLevel
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
-from lightllm.utils.envs_utils import set_model_init_status, set_triton_autotune_level, set_triton_autotune_warmup
+from lightllm.utils.envs_utils import set_model_init_status, set_triton_autotune_level, get_triton_autotune_level
+from lightllm.common.triton_utils.autotuner import Autotuner
 from lightllm.utils.infer_utils import post_empty_cache

 logger = init_logger(__name__)
@@ -733,7 +734,7 @@ def autotune_layers(self):
     @torch.no_grad()
     @post_empty_cache
     def _autotune_warmup(self):
-        set_triton_autotune_warmup(1)
+        Autotuner.start_autotune_warmup()
         torch.distributed.barrier()

         warmup_lengths = [1, 8, 16, 32, 64, 100, 128, 256, 1024, 2048, 4096]
@@ -792,8 +793,7 @@ def _autotune_warmup(self):
         torch.cuda.empty_cache()
         self.layers_num = layer_num_bak
         torch.distributed.barrier()
-        set_triton_autotune_level(AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG)
-        set_triton_autotune_warmup(0)
+        Autotuner.end_autotune_warmup()

     @final
     @torch.no_grad()

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 1 addition & 1 deletion
@@ -576,7 +576,7 @@ def grouped_matmul(
     if block_size_k != 0:
         # If block-wise quantization is used, the tile size must not exceed the block size
         BLOCK_SIZE_K = min(BLOCK_SIZE_K, block_size_k)
-        BLOCK_SIZE_K = triton.next_power_of_2(BLOCK_SIZE_K)
+        assert BLOCK_SIZE_K == triton.next_power_of_2(BLOCK_SIZE_K)

     if use_fp8_w8a8:
         # When weights use block-wise quantization, activations are also quantized per token with group size
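
Note on this hunk: the old line rounded BLOCK_SIZE_K up to the next power of two, while the new line only asserts that it already is one. That holds as long as both the candidate BLOCK_SIZE_K and the quantization block_size_k are powers of two, since min() of two powers of two is itself a power of two. A standalone sanity check of that reasoning, using illustrative values that are not part of the commit:

import triton

def is_pow2(x: int) -> bool:
    # True when x is a positive power of two.
    return x > 0 and (x & (x - 1)) == 0

for block_size_k in (64, 128):          # typical block-wise quantization group sizes (assumed)
    for BLOCK_SIZE_K in (32, 64, 128):  # candidate tile sizes (assumed)
        k = min(BLOCK_SIZE_K, block_size_k)
        assert is_pow2(k)
        assert k == triton.next_power_of_2(k)  # the condition the kernel now asserts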

lightllm/common/triton_utils/autotuner.py

Lines changed: 31 additions & 7 deletions
@@ -77,6 +77,22 @@ def decorator(fn):


 class Autotuner:
+    _autotune_warmup: bool = False
+
+    @staticmethod
+    def start_autotune_warmup():
+        Autotuner._autotune_warmup = True
+        return
+
+    @staticmethod
+    def end_autotune_warmup():
+        Autotuner._autotune_warmup = False
+        return
+
+    @staticmethod
+    def is_autotune_warmup():
+        return Autotuner._autotune_warmup
+
     def __init__(
         self,
         fn,
@@ -139,12 +155,12 @@ def __call__(self, *args, **kwargs):
         run_key = str(self._run_key(*args, **kwargs))

         # Lazy load the cached configs in lightllm/common/triton_utils/autotune_kernel_configs
-        if self._try_load_cache(static_key) or get_triton_autotune_warmup():
+        if self._try_load_cache(static_key) or Autotuner.is_autotune_warmup():
             all_configs = self.cached_configs.get(static_key, {})
             for run_config in all_configs.values():
-                # warmup
+                # warmup all configs
                 kwargs["run_config"] = run_config
-                self._bench(*args, n_repeat=1, n_retries=1, warmup=True, **kwargs)
+                self.kernel_warmup(*args, **kwargs)

         if static_key not in self.cached_configs and autotune_level == AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG:
             if (dist.is_initialized() and get_current_rank_in_node() == 0) or not dist.is_initialized():
@@ -199,7 +215,18 @@ def _try_load_cache(self, static_key):
             self.cached_configs[static_key] = orjson.loads(f.read())
             return True

-    def _bench(self, *args, n_repeat=3, n_retries=3, warmup=False, **kwargs):
+    def kernel_warmup(self, *args, **kwargs):
+        new_args, new_kwargs, origin_list, new_list = self._mutate_args_clone(args, kwargs)
+
+        try:
+            self.fn(*new_args, **new_kwargs)
+        except:
+            pass
+        finally:
+            self._recover_mutated_args(origin_list=origin_list, new_list=new_list)
+        return
+
+    def _bench(self, *args, n_repeat=3, n_retries=3, **kwargs):
         from triton.compiler.errors import CompileTimeAssertionFailure
         from triton.runtime.errors import OutOfResources, PTXASError

@@ -209,16 +236,13 @@ def kernel_call():
             try:
                 self.fn(*new_args, **new_kwargs)
             except Exception as e:
-                print(f"error: {e}")
                 raise e
             finally:
                 self._recover_mutated_args(origin_list=origin_list, new_list=new_list)

         try:
             # warmup
             kernel_call()
-            if warmup:
-                return

             torch.cuda.current_stream().synchronize()
             g = torch.cuda.CUDAGraph()
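
Taken together, these hunks replace the LIGHTLLM_TRITON_AUTOTUNE_WARMUP environment variable with a process-local class flag: callers bracket their warmup passes with start_autotune_warmup()/end_autotune_warmup(), and while the flag is set, Autotuner.__call__ replays every cached run_config through kernel_warmup instead of benchmarking. A rough usage sketch; the warmup_model wrapper, model object, and warmup_batches are hypothetical, only the Autotuner class methods come from this commit:

from lightllm.common.triton_utils.autotuner import Autotuner

def warmup_model(model, warmup_batches):
    # Hypothetical wrapper: bracket warmup forwards with the new class-level flag.
    Autotuner.start_autotune_warmup()
    try:
        for batch in warmup_batches:
            # Each forward pass hits the autotuned Triton kernels; with the flag set,
            # Autotuner.__call__ warms every cached run_config via kernel_warmup().
            model.forward(batch)
    finally:
        # Restore normal behavior: kernels run with their cached best config.
        Autotuner.end_autotune_warmup()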

lightllm/utils/envs_utils.py

Lines changed: 1 addition & 14 deletions
@@ -149,24 +149,11 @@ def get_kv_quant_calibration_inference_count():
     return int(os.getenv("LIGHTLLM_KV_QUANT_CALIBRARTION_INFERENCE_COUNT", 4000))


+@lru_cache(maxsize=None)
 def get_triton_autotune_level():
     return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))


-def set_triton_autotune_level(level: int):
-    os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = str(level)
-    return
-
-
-def set_triton_autotune_warmup(warmup: int):
-    os.environ["LIGHTLLM_TRITON_AUTOTUNE_WARMUP"] = str(warmup)
-    return
-
-
-def get_triton_autotune_warmup():
-    return os.getenv("LIGHTLLM_TRITON_AUTOTUNE_WARMUP", "0") == "1"
-
-
 g_model_init_done = False

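
With the runtime setters removed, LIGHTLLM_TRITON_AUTOTUNE_LEVEL is no longer mutated after startup, which is presumably why the getter can now be memoized with @lru_cache. A minimal standalone illustration of that caching behavior (not lightllm code):

import os
from functools import lru_cache

@lru_cache(maxsize=None)
def get_level() -> int:
    # Reads the env var once; later calls return the memoized value.
    return int(os.getenv("LIGHTLLM_TRITON_AUTOTUNE_LEVEL", 0))

os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = "2"
assert get_level() == 2
os.environ["LIGHTLLM_TRITON_AUTOTUNE_LEVEL"] = "3"  # a later change is ignored
assert get_level() == 2                             # because the first result is cached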
