
Commit 1030efb

refactor
1 parent 452e423 commit 1030efb

16 files changed with 48 additions and 1,365 deletions.

lightllm/common/basemodel/basemodel.py

Lines changed: 8 additions & 4 deletions
@@ -81,6 +81,10 @@ def __init__(self, kvargs):

         self._init_datatype()
         self._init_config()
+
+        if os.environ.get("LIGHTLLM_TRITON_AUTOTUNE", "0") == "1":
+            self.layers_num = self.autotune_layers()
+
         self._verify_must()
         self._verify_params()
         self._init_quant()
@@ -744,8 +748,6 @@ def _autotune_warmup(self):

         warmup_lengths.sort(reverse=True)

-        layer_num_bak = self.layers_num
-        self.layers_num = self.autotune_layers()
         for input_len in warmup_lengths:
             try:
                 logger.info(f"autotune warmup for length {input_len}")
@@ -777,14 +779,16 @@
                 del model_output
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
+                torch.cuda.empty_cache()
                 logger.info(f"autotune warmup for length {input_len} ok")
             except Exception as e:
                 logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}")
                 self.req_manager.free_all()
                 self.mem_manager.free_all()
-        self.layers_num = layer_num_bak
+                torch.cuda.empty_cache()
         torch.distributed.barrier()
-        os.environ["LIGHTLLM_TRITON_AUTOTUNE"] = "0"
+        logger.info("autotune warmup done, exit!")
+        exit(0)

     @final
     @torch.no_grad()
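With this change the warmup pass no longer restores layers_num or clears the env var in-process; a tuning run now finishes by logging and calling exit(0). The snippet below is a minimal launcher sketch of the resulting two-phase workflow, not code from this commit: the assumption that the same entrypoint both tunes and serves, and the command line passed in, are placeholders.

import os
import subprocess
import sys

def launch_with_autotune(server_cmd):
    # Phase 1: tuning run. With LIGHTLLM_TRITON_AUTOTUNE=1 the model trims
    # layers_num via autotune_layers(), warms up the Triton kernels over the
    # sampled lengths, caches the configs, and exits with status 0.
    subprocess.run(server_cmd, env=dict(os.environ, LIGHTLLM_TRITON_AUTOTUNE="1"), check=True)

    # Phase 2: normal serving run, started only after tuning succeeded; it
    # picks up the cached kernel configs and keeps the full layers_num.
    subprocess.run(server_cmd, env=dict(os.environ, LIGHTLLM_TRITON_AUTOTUNE="0"), check=True)

if __name__ == "__main__":
    launch_with_autotune(sys.argv[1:])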

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 3 additions & 3 deletions
@@ -478,7 +478,7 @@ def get_grouped_matmul_static_key(
             "BLOCK_SIZE_N": bn,
             "BLOCK_SIZE_K": bk,
             "GROUP_SIZE_M": gm,
-            "NUM_WARPS": nw,
+            "num_warps": nw,
             "NUM_STAGE": ns,
         }
         for ns in [1, 2, 3, 4, 5]
@@ -493,7 +493,7 @@ def get_grouped_matmul_static_key(
             "BLOCK_SIZE_N": 64,
             "BLOCK_SIZE_K": 32,
             "GROUP_SIZE_M": 8,
-            "NUM_WARPS": 4,
+            "num_warps": 4,
             "NUM_STAGE": 1,
         },
         static_key_func=get_grouped_matmul_static_key,
@@ -550,7 +550,7 @@ def grouped_matmul(
     BLOCK_SIZE_N = run_config["BLOCK_SIZE_N"]
     BLOCK_SIZE_K = run_config["BLOCK_SIZE_K"]
     GROUP_SIZE_M = run_config["GROUP_SIZE_M"]
-    num_warps = run_config["NUM_WARPS"]
+    num_warps = run_config["num_warps"]
     num_stages = run_config["NUM_STAGE"]

     if block_size_k != 0:
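The only behavioral point in this file is the key rename: lowercase num_warps matches the name Triton uses for its kernel-launch option, which is presumably why the config key was changed, so the value read back from a cached JSON config keeps one spelling all the way to the launch. The toy kernel below only illustrates that pattern under the same key layout; it is not lightllm's grouped_matmul kernel.

import torch
import triton
import triton.language as tl

@triton.jit
def add_kernel(x_ptr, y_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    # Minimal elementwise add, standing in for a real grouped matmul kernel.
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    y = tl.load(y_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)

# Upper-case keys are constexpr arguments consumed by the kernel body;
# lower-case num_warps is a launch option consumed by Triton itself.
run_config = {"BLOCK_SIZE": 1024, "num_warps": 4, "NUM_STAGE": 1}

x = torch.randn(4096, device="cuda")
y = torch.randn(4096, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), run_config["BLOCK_SIZE"]),)
add_kernel[grid](
    x, y, out, x.numel(),
    BLOCK_SIZE=run_config["BLOCK_SIZE"],
    num_warps=run_config["num_warps"],
    num_stages=run_config["NUM_STAGE"],
)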

lightllm/common/triton_utils/all_kernel_configs/triton_3.3.1/NVIDIA H200/grouped_matmul:v1/K=192,N=4096,expert_num=128,mul_routed_weight=True,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=False.json

Lines changed: 0 additions & 122 deletions
This file was deleted.

lightllm/common/triton_utils/all_kernel_configs/triton_3.3.1/NVIDIA H200/grouped_matmul:v1/K=192,N=4096,expert_num=128,mul_routed_weight=True,out_dtype=torch.bfloat16,topk_num=1,use_fp8_w8a8=True.json

Lines changed: 0 additions & 122 deletions
This file was deleted.
