Skip to content

Commit dfeb62d

Browse files
authored
Merge branch 'ModelTC:main' into mineru-adapt
2 parents ccb9d62 + 9f2f0cf commit dfeb62d

37 files changed

+542
-115
lines changed

docker/Dockerfile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir
3939

4040
RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
4141

42-
RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
42+
# TODO: offline compile
43+
# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
4344

4445
RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel
4546

docker/Dockerfile.deepep

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir
3939

4040
RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
4141

42-
RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
42+
# TODO: offline compile
43+
# RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v .
4344

4445
RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms
4546
RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev

lightllm/common/basemodel/basemodel.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,9 @@
2424
from lightllm.utils.envs_utils import get_env_start_args
2525
from lightllm.distributed.communication_op import dist_group_manager
2626
from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
27+
from lightllm.common.triton_utils.autotuner import AutotuneLevel
2728
from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
28-
from lightllm.utils.envs_utils import set_model_init_status, is_triton_autotune_enabled, disable_triton_autotune
29+
from lightllm.utils.envs_utils import set_model_init_status, set_triton_autotune_level, get_triton_autotune_level
2930
from lightllm.utils.infer_utils import post_empty_cache
3031

3132
logger = init_logger(__name__)
@@ -731,7 +732,7 @@ def autotune_layers(self):
731732
@torch.no_grad()
732733
@post_empty_cache
733734
def _autotune_warmup(self):
734-
if not is_triton_autotune_enabled():
735+
if get_triton_autotune_level() not in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
735736
return
736737

737738
torch.distributed.barrier()
@@ -794,7 +795,7 @@ def _autotune_warmup(self):
794795
torch.cuda.empty_cache()
795796
self.layers_num = layer_num_bak
796797
torch.distributed.barrier()
797-
disable_triton_autotune()
798+
set_triton_autotune_level(AutotuneLevel.USE_AUTOTUNE_HIS_CONFIG)
798799

799800
@final
800801
@torch.no_grad()

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
)
1818
from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
1919
from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
20-
from lightllm.utils.envs_utils import is_triton_autotune_enabled
20+
from lightllm.utils.envs_utils import get_triton_autotune_level
21+
from lightllm.common.triton_utils.autotuner import AutotuneLevel
2122
from lightllm.utils.log_utils import init_logger
2223

2324
logger = init_logger(__name__)
@@ -358,7 +359,7 @@ def prefilled_group_gemm(
358359
######################################## warning ##################################################
359360
# here is used to match autotune feature, make moe model run same triton kernel in different rank.
360361
# in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
361-
if is_triton_autotune_enabled():
362+
if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
362363
_gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
363364
_silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
364365
silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 36 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -332,6 +332,7 @@ def grouped_matmul_kernel(
332332
GROUP_SIZE_M: tl.constexpr,
333333
MUL_ROUTED_WEIGHT: tl.constexpr = False,
334334
NEED_K_MASK: tl.constexpr = True,
335+
NEED_TRANS: tl.constexpr = False,
335336
):
336337
pid = tl.program_id(0)
337338

@@ -367,13 +368,6 @@ def grouped_matmul_kernel(
367368
mask=token_mask,
368369
other=0,
369370
)
370-
if MUL_ROUTED_WEIGHT:
371-
a_m_scale = tl.load(
372-
expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
373-
mask=token_mask,
374-
other=0.0,
375-
)
376-
377371
offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
378372
offs_k = tl.arange(0, BLOCK_SIZE_K)
379373

@@ -387,7 +381,7 @@ def grouped_matmul_kernel(
387381
b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last")
388382
ab_scale = a_scale * b_scale
389383

390-
if use_fp8_w8a8:
384+
if NEED_TRANS:
391385
a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
392386
b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
393387
accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
@@ -401,16 +395,20 @@ def grouped_matmul_kernel(
401395
# tl.multiple_of(a_ptrs, [16, 16])
402396
# tl.multiple_of(b_ptrs, [16, 16])
403397

404-
if use_fp8_w8a8:
398+
if NEED_TRANS:
405399
if NEED_K_MASK:
406-
a = tl.load(a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k), other=0.0)
400+
a = tl.load(
401+
a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k - step_k * BLOCK_SIZE_K), other=0.0
402+
)
407403
b = tl.load(b_ptrs, mask=(offs_k[None, :] < k), other=0.0)
408404
else:
409405
a = tl.load(a_ptrs, mask=(token_mask[None, :]), other=0.0)
410406
b = tl.load(b_ptrs)
411407
else:
412408
if NEED_K_MASK:
413-
a = tl.load(a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k), other=0.0)
409+
a = tl.load(
410+
a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k - step_k * BLOCK_SIZE_K), other=0.0
411+
)
414412
b = tl.load(b_ptrs, mask=(offs_k[:, None] < k), other=0.0)
415413
else:
416414
a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0)
@@ -421,24 +419,34 @@ def grouped_matmul_kernel(
421419
offs_ks = step_k * BLOCK_SIZE_K // block_size_k
422420
a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0)
423421
b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2)
424-
accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
422+
if NEED_TRANS:
423+
accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
424+
else:
425+
accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
425426
else:
426-
accumulator = tl.dot(b, a, acc=accumulator)
427+
if NEED_TRANS:
428+
accumulator = tl.dot(b, a, acc=accumulator)
429+
else:
430+
accumulator = tl.dot(a, b, acc=accumulator)
427431
else:
428432
accumulator += tl.dot(a, b)
429433

430434
a_ptrs += BLOCK_SIZE_K
431435
b_ptrs += BLOCK_SIZE_K
432-
offs_k += BLOCK_SIZE_K
436+
437+
if NEED_TRANS:
438+
accumulator = accumulator.T
433439

434440
if use_fp8_w8a8:
435-
if block_size_k > 0 and block_size_n > 0:
436-
accumulator = accumulator.T
437-
else:
438-
accumulator = accumulator.T
441+
if not (block_size_k > 0 and block_size_n > 0):
439442
accumulator *= ab_scale
440443

441444
if MUL_ROUTED_WEIGHT:
445+
a_m_scale = tl.load(
446+
expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
447+
mask=token_mask,
448+
other=0.0,
449+
)
442450
accumulator *= a_m_scale[:, None]
443451

444452
c = accumulator.to(compute_type)
@@ -478,13 +486,15 @@ def _get_grouped_matmul_configs():
478486
"GROUP_SIZE_M": gm,
479487
"num_warps": nw,
480488
"num_stages": ns,
489+
"NEED_TRANS": need_trans,
481490
}
482-
for ns in [1, 2, 3, 4, 5]
483-
for gm in [1, 2, 4, 8]
484-
for nw in [2, 4, 8]
491+
for ns in [2, 3, 4, 5]
492+
for gm in [1, 16, 32, 64]
493+
for nw in [4, 8]
485494
for bm in [16, 32, 64, 128]
486495
for bn in [16, 32, 64, 128]
487-
for bk in [16, 32, 64, 128]
496+
for bk in [32, 64, 128]
497+
for need_trans in [True, False]
488498
]
489499

490500

@@ -559,6 +569,9 @@ def grouped_matmul(
559569
GROUP_SIZE_M = run_config["GROUP_SIZE_M"]
560570
num_warps = run_config["num_warps"]
561571
num_stages = run_config["num_stages"]
572+
NEED_TRANS = run_config.get("NEED_TRANS", False)
573+
if not use_fp8_w8a8:
574+
assert NEED_TRANS is False, "only use_fp8_w8a8 mode can use NEED_TRANS to accelerate"
562575

563576
if block_size_k != 0:
564577
# 如果使用了 block wise 量化,分块大小不能超过 block size
@@ -638,6 +651,7 @@ def grouped_matmul(
638651
GROUP_SIZE_M=GROUP_SIZE_M,
639652
MUL_ROUTED_WEIGHT=mul_routed_weight,
640653
NEED_K_MASK=NEED_K_MASK,
654+
NEED_TRANS=NEED_TRANS,
641655
num_warps=num_warps,
642656
num_stages=num_stages,
643657
)

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
)
1515
from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
1616
from lightllm.utils.envs_utils import get_deepep_num_max_dispatch_tokens_per_rank
17-
from lightllm.utils.envs_utils import is_triton_autotune_enabled
17+
from lightllm.utils.envs_utils import get_triton_autotune_level
18+
from lightllm.common.triton_utils.autotuner import AutotuneLevel
1819
import numpy as np
1920

2021
logger = init_logger(__name__)
@@ -191,7 +192,7 @@ def fused_experts_impl(
191192
######################################## warning ##################################################
192193
# here is used to match autotune feature, make moe model run same triton kernel in different rank.
193194
# in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
194-
if is_triton_autotune_enabled():
195+
if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
195196
_gemm_out_a = torch.zeros((1, N), device=hidden_states.device, dtype=hidden_states.dtype)
196197
_silu_out = torch.zeros((1, N // 2), device=hidden_states.device, dtype=hidden_states.dtype)
197198
silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/fused_moe/grouped_topk.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def grouped_topk_kernel(
140140
offs_group = tl.arange(0, EXPERT_GROUP_NUM)
141141
offs_group_v = tl.arange(0, EXPERT_GROUP_SIZE)
142142
tl.store(scores_buffer_ptr + scores_stride_m * token_index + offs_n, scores, mask=offs_n < total_expert_num)
143+
tl.debug_barrier()
143144
group_scores = tl.load(
144145
scores_buffer_ptr
145146
+ scores_stride_token_m * token_index
@@ -174,7 +175,7 @@ def grouped_topk_kernel(
174175
mask_group_scores,
175176
mask=((offs_group < group_num)[:, None]) & ((offs_group_v < group_expert_num)[None, :]),
176177
) # [group, group_size]
177-
178+
tl.debug_barrier()
178179
mask_scores = tl.load(
179180
scores_buffer_ptr + scores_stride_m * token_index + offs_n, mask=offs_n < total_expert_num, other=-10000000.0
180181
)
@@ -227,6 +228,11 @@ def triton_grouped_topk(
227228

228229
assert total_expert_num % num_expert_group == 0
229230

231+
if token_num <= 256:
232+
num_warps = 4
233+
else:
234+
num_warps = 1
235+
230236
grouped_topk_kernel[(token_num,)](
231237
gating_output,
232238
*gating_output.stride(),
@@ -250,7 +256,7 @@ def triton_grouped_topk(
250256
EXPERT_GROUP_SIZE=triton.next_power_of_2(total_expert_num // num_expert_group),
251257
RENORMALIZE=renormalize,
252258
GROUP_SCORE_USED_TOPK_NUM=group_score_used_topk_num,
253-
num_warps=1,
259+
num_warps=num_warps,
254260
num_stages=1,
255261
)
256262
return out_topk_weights, out_topk_ids

lightllm/common/fused_moe/moe_kernel_configs.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def try_to_get_best_config(
4646
"BLOCK_SIZE_N": 32,
4747
"BLOCK_SIZE_K": 64,
4848
"GROUP_SIZE_M": 1,
49+
"NEED_TRANS": False,
4950
"num_warps": 4,
5051
"num_stages": 1,
5152
}
@@ -55,6 +56,7 @@ def try_to_get_best_config(
5556
"BLOCK_SIZE_N": 64,
5657
"BLOCK_SIZE_K": 32,
5758
"GROUP_SIZE_M": 8,
59+
"NEED_TRANS": False,
5860
"num_warps": 4,
5961
"num_stages": 1,
6062
}

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def _get_silu_and_mul_configs():
6868
{"BLOCK_M": bm, "BLOCK_N": bn, "num_warps": nw, "NUM_STAGES": ns}
6969
for ns in [1, 2, 4]
7070
for nw in [1, 4, 8]
71-
for bm in [32, 64, 128, 256]
71+
for bm in [1, 8, 32, 64, 128, 256]
7272
for bn in [32, 64, 128, 256]
7373
]
7474

lightllm/common/fused_moe/topk_select.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
from lightllm.utils.light_utils import light_ops
2424
from typing import Callable, List, Optional, Tuple
2525
from lightllm.common.fused_moe.softmax_topk import softmax_topk
26+
from lightllm.common.triton_utils.autotuner import AutotuneLevel
27+
from lightllm.utils.envs_utils import get_triton_autotune_level
2628

2729
use_cuda_grouped_topk = os.getenv("LIGHTLLM_CUDA_GROUPED_TOPK", "False").upper() in ["ON", "TRUE", "1"]
2830

@@ -221,4 +223,12 @@ def select_experts(
221223
hidden_states=hidden_states, gating_output=router_logits, topk=top_k, renormalize=renormalize
222224
)
223225

226+
######################################## warning ##################################################
227+
# here is used to match autotune feature, make topk_ids more random
228+
if get_triton_autotune_level() in [AutotuneLevel.ADAPTIVE_AUTOTUNE, AutotuneLevel.FORCE_AUTOTUNE]:
229+
rand_gen = torch.Generator(device="cuda")
230+
rand_gen.manual_seed(router_logits.shape[0])
231+
router_logits = torch.randn(size=router_logits.shape, generator=rand_gen, dtype=torch.float32, device="cuda")
232+
_, topk_ids = torch.topk(router_logits, k=top_k, dim=1)
233+
224234
return topk_weights, topk_ids

0 commit comments

Comments
 (0)