Commit 075f65d: merge main

2 parents d45865c + ae0ac10

63 files changed (+2465, -174 lines)

README.md

Lines changed: 16 additions & 1 deletion
@@ -21,7 +21,9 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
 [English Docs](https://lightllm-en.readthedocs.io/en/latest/) | [中文文档](https://lightllm-cn.readthedocs.io/en/latest/) | [Blogs](https://modeltc.github.io/lightllm-blog/)
 
 ## News
-- [2025/05] LightLLM paper on constrained decoding accepted by [ACL25](https://arxiv.org/pdf/2506.03887) (Pre $^3$: Enabling Deterministic Pushdown Automata for Faster Structured LLM Generation). For a more accessible overview of the research with key insights and examples, check out our blog post: [LightLLM Blog](https://www.light-ai.top/lightllm-blog/2025/06/15/pre3.html)
+- [2025/09] 🔥 LightLLM [v1.1.0](https://www.light-ai.top/lightllm-blog/2025/09/03/lightllm.html) release!
+- [2025/08] Pre $^3$ achieves the outstanding paper award of [ACL2025](https://2025.aclweb.org/program/awards/).
+- [2025/05] LightLLM paper on constrained decoding accepted by [ACL2025](https://arxiv.org/pdf/2506.03887) (Pre $^3$: Enabling Deterministic Pushdown Automata for Faster Structured LLM Generation). For a more accessible overview of the research with key insights and examples, check out our blog post: [LightLLM Blog](https://www.light-ai.top/lightllm-blog/2025/06/15/pre3.html)
 - [2025/04] LightLLM paper on request scheduler published in [ASPLOS’25](https://dl.acm.org/doi/10.1145/3676641.3716011) (Past-Future Scheduler for LLM Serving under SLA Guarantees)
 - [2025/02] 🔥 LightLLM v1.0.0 release, achieving the **fastest DeepSeek-R1** serving performance on single H200 machine.
@@ -90,6 +92,19 @@ We learned a lot from the following projects when developing LightLLM.
 
 We have published a number of papers around components or features of LightLLM, if you use LightLLM in your work, please consider citing the relevant paper.
 
+**constrained decoding**: accepted by [ACL2025](https://arxiv.org/pdf/2506.03887) and achieved the outstanding paper award.
+```bibtex
+@inproceedings{
+anonymous2025pre,
+title={Pre\${\textasciicircum}3\$: Enabling Deterministic Pushdown Automata for Faster Structured {LLM} Generation},
+author={Anonymous},
+booktitle={Submitted to ACL Rolling Review - February 2025},
+year={2025},
+url={https://openreview.net/forum?id=g1aBeiyZEi},
+note={under review}
+}
+```
+
 **Request scheduler**: accepted by [ASPLOS’25](https://dl.acm.org/doi/10.1145/3676641.3716011):
 ```bibtex
 @inproceedings{gong2025past,

docker/Dockerfile

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir
 
 RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
 
-RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \
-    cd flash-attention/hopper/ && python setup.py install
+RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl && \
+    pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
 
 RUN apt-get update && apt-get install -y libnuma-dev # for sgl_kernel

docker/Dockerfile.deepep

Lines changed: 2 additions & 2 deletions
@@ -39,8 +39,8 @@ RUN pip install -r /lightllm/requirements.txt --no-cache-dir
 
 RUN pip install --no-cache-dir vllm --pre --extra-index-url https://wheels.vllm.ai/nightly
 
-RUN git clone https://github.com/ModelTC/LightKernel.git && cd LightKernel && pip install --no-deps -v . && \
-    cd flash-attention/hopper/ && python setup.py install
+RUN pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl && \
+    pip install https://github.com/ModelTC/LightKernel/releases/download/v1.0.1/flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl
 
 RUN apt-get update && apt-get install -y libnuma-dev wget devscripts debhelper dh-make build-essential dkms
 RUN apt-get install -y ibverbs-providers infiniband-diags perftest rdma-core libibverbs-dev librdmacm-dev
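Both Dockerfiles now install prebuilt LightKernel wheels instead of building from source. Below is a minimal sketch, not part of the commit, for checking whether the wheel tags above match the interpreter inside the image; the filenames are copied from the diff, and the availability of the `packaging` helpers in the build environment is an assumption:

```python
# Hedged sanity check: do the prebuilt LightKernel wheel tags match this interpreter?
from packaging.tags import sys_tags
from packaging.utils import parse_wheel_filename

WHEELS = [
    "lightllm_kernel-0.1.0-cp310-cp310-linux_x86_64.whl",
    "flash_attn_3-3.0.0b1-cp39-abi3-linux_x86_64.whl",
]

supported = set(sys_tags())  # all tags the running interpreter accepts
for filename in WHEELS:
    _, _, _, tags = parse_wheel_filename(filename)
    status = "compatible" if tags & supported else "NOT compatible"
    print(f"{filename}: {status}")
```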

lightllm/common/basemodel/basemodel.py

Lines changed: 8 additions & 9 deletions
@@ -7,6 +7,7 @@
 import torch
 import torch.nn.functional as F
 from typing import final
+from tqdm import tqdm
 
 from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
 from lightllm.common.basemodel.infer_struct import InferStateInfo
@@ -24,8 +25,10 @@
 from lightllm.utils.envs_utils import get_env_start_args
 from lightllm.distributed.communication_op import dist_group_manager
 from lightllm.common.basemodel.batch_objs import ModelInput, ModelOutput
+from lightllm.common.triton_utils.autotuner import AutotuneLevel
 from lightllm.utils.custom_kernel_utis import pad2dim_tensor_to_new_batch
-from lightllm.utils.envs_utils import set_model_init_status, is_triton_autotune_enabled, disable_triton_autotune
+from lightllm.utils.envs_utils import set_model_init_status
+from lightllm.common.triton_utils.autotuner import Autotuner
 from lightllm.utils.infer_utils import post_empty_cache
 
 logger = init_logger(__name__)
@@ -731,12 +734,10 @@ def autotune_layers(self):
     @torch.no_grad()
     @post_empty_cache
     def _autotune_warmup(self):
-        if not is_triton_autotune_enabled():
-            return
-
+        Autotuner.start_autotune_warmup()
         torch.distributed.barrier()
 
-        warmup_lengths = [1, 8, 16, 64, 128, 256, 1024, 2048, 4096]
+        warmup_lengths = [1, 8, 16, 32, 64, 100, 128, 256, 1024, 2048, 4096]
 
         if self.batch_max_tokens not in warmup_lengths:
             warmup_lengths.append(self.batch_max_tokens)
@@ -747,9 +748,8 @@ def _autotune_warmup(self):
 
         layer_num_bak = self.layers_num
         self.layers_num = self.autotune_layers()
-        for input_len in warmup_lengths:
+        for input_len in tqdm(warmup_lengths, desc="warming up"):
             try:
-                logger.info(f"autotune warmup for length {input_len}")
                 rand_gen = torch.Generator(device="cuda")
                 rand_gen.manual_seed(input_len)
                 dummy_input_ids = torch.randint(
@@ -784,7 +784,6 @@ def _autotune_warmup(self):
                 self.mem_manager.free_all()
                 gc.collect()
                 torch.cuda.empty_cache()
-                logger.info(f"autotune warmup for length {input_len} ok")
             except Exception as e:
                 logger.warning(f"autotune warmup for length {input_len} failed: {str(e)}")
                 logger.exception(str(e))
@@ -794,7 +793,7 @@ def _autotune_warmup(self):
         torch.cuda.empty_cache()
         self.layers_num = layer_num_bak
         torch.distributed.barrier()
-        disable_triton_autotune()
+        Autotuner.end_autotune_warmup()
 
     @final
     @torch.no_grad()
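A minimal sketch of the new warmup bracket, assuming only the Autotuner classmethods visible in this commit (start_autotune_warmup, is_autotune_warmup, end_autotune_warmup); dummy_forward below is a hypothetical stand-in for the randint-based dummy prefill that _autotune_warmup actually builds:

```python
from lightllm.common.triton_utils.autotuner import Autotuner


def dummy_forward(input_len: int) -> None:
    # hypothetical placeholder for the dummy prefill pass run at each warmup length
    pass


def autotune_warmup(warmup_lengths):
    Autotuner.start_autotune_warmup()  # presumably switches Triton kernels into tuning mode
    try:
        for input_len in warmup_lengths:
            dummy_forward(input_len)
    finally:
        Autotuner.end_autotune_warmup()  # kernels fall back to their cached configs afterwards
```

Kernel-side code can branch on `Autotuner.is_autotune_warmup()`, as `fused_moe_weight_ep.py` does below.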

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 10 additions & 12 deletions
@@ -4,7 +4,11 @@
 from typing import Optional, Tuple, List, Dict, Any
 from lightllm.utils.dist_utils import get_global_world_size, get_global_rank, get_current_device_id
 from .base_weight import BaseWeight
-from lightllm.common.fused_moe.grouped_fused_moe_ep import fused_experts_impl, masked_group_gemm
+from lightllm.common.fused_moe.grouped_fused_moe_ep import (
+    fused_experts_impl,
+    masked_group_gemm,
+    _deepgemm_grouped_fp8_nt_contiguous,
+)
 from lightllm.common.fused_moe.moe_silu_and_mul import silu_and_mul_fwd
 from lightllm.distributed import dist_group_manager
 from lightllm.common.fused_moe.topk_select import select_experts
@@ -17,15 +21,11 @@
 )
 from lightllm.common.fused_moe.deepep_scatter_gather import ep_scatter, ep_gather
 from lightllm.common.basemodel.triton_kernel.redundancy_topk_ids_repair import redundancy_topk_ids_repair
-from lightllm.utils.envs_utils import is_triton_autotune_enabled
 from lightllm.utils.log_utils import init_logger
+from lightllm.common.triton_utils.autotuner import Autotuner
 
-logger = init_logger(__name__)
 
-try:
-    import deep_gemm
-except:
-    logger.warning("no deepep or deep_gemm")
+logger = init_logger(__name__)
 
 
 class FusedMoeWeightEP(BaseWeight):
@@ -335,7 +335,7 @@ def prefilled_group_gemm(
         # groupgemm (contiguous layout)
         gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype)
 
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)
+        _deepgemm_grouped_fp8_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)
 
         # silu_and_mul_fwd + qaunt
         # TODO fused kernel
@@ -349,16 +349,14 @@ def prefilled_group_gemm(
         # groupgemm (contiguous layout)
         gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)
 
-        deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
-            (qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices
-        )
+        _deepgemm_grouped_fp8_nt_contiguous((qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices)
         # gather and local reduce
         ep_gather(gemm_out_b, recv_topk_idx, recv_topk_weights, output_index, gather_out)
     else:
         ######################################## warning ##################################################
         # here is used to match autotune feature, make moe model run same triton kernel in different rank.
         # in some special case, one rank will recv 0 token, so add a token to make it run triton kernel.
-        if is_triton_autotune_enabled():
+        if Autotuner.is_autotune_warmup():
            _gemm_out_a = torch.zeros((1, N), device=device, dtype=hidden_dtype)
            _silu_out = torch.zeros((1, N // 2), device=device, dtype=hidden_dtype)
            silu_and_mul_fwd(_gemm_out_a.view(-1, N), _silu_out)

lightllm/common/basemodel/triton_kernel/apply_penalty.py

Lines changed: 1 addition & 1 deletion
@@ -61,7 +61,7 @@ def _fwd_kernel_apply_penalty(
     cur_eos_logit_ptr = Logits + cur_batch * stride_logit_b + eos_id
     cur_eos_logit = tl.load(cur_eos_logit_ptr)
     cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
-    cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
+    cur_eos_logit = tl.where(mask_eos != 0, -10000000.0, cur_eos_logit)
     tl.store(cur_eos_logit_ptr, cur_eos_logit)
     return

lightllm/common/basemodel/triton_kernel/apply_penalty_gpu_cache.py

Lines changed: 1 addition & 1 deletion
@@ -76,7 +76,7 @@ def _eos_penalty(
     cur_eos_logit_ptr = Logits + offs * stride_logit_b + eos_id
     cur_eos_logit = tl.load(cur_eos_logit_ptr, mask=mask, other=0.0)
     cur_eos_logit = cur_eos_logit + tl.abs(cur_eos_logit) * penalty_scale
-    cur_eos_logit = tl.where(mask_eos, -10000000.0, cur_eos_logit)
+    cur_eos_logit = tl.where(mask_eos != 0, -10000000.0, cur_eos_logit)
     tl.store(cur_eos_logit_ptr, cur_eos_logit, mask=mask)
     return

lightllm/common/basemodel/triton_kernel/gather_token_id.py

Lines changed: 4 additions & 0 deletions
@@ -16,6 +16,7 @@ def _fwd_kernel_scatter(
     num_size,
     HAS_OUT_IS_NONE: tl.constexpr,
     BLOCK: tl.constexpr,
+    OLD_VERSION_TRITON: tl.constexpr,
 ):
     block_index = tl.program_id(0)
     block_range = block_index * BLOCK + tl.arange(0, BLOCK)
@@ -27,6 +28,8 @@
 
     if not HAS_OUT_IS_NONE:
         cur_has_out = tl.load(b_has_out + block_range, mask=block_mask, other=False)
+        if OLD_VERSION_TRITON:
+            cur_has_out = cur_has_out != 0
     tl.store(
         req_to_next_token_ids + cur_req_idx * req_to_next_token_ids_stride + cur_mtp_index,
         cur_next_token_id,
@@ -76,6 +79,7 @@
         num_size=batch_size,
         HAS_OUT_IS_NONE=b_has_out is None,
         BLOCK=BLOCK,
+        OLD_VERSION_TRITON=triton.__version__ < "3.2.0",
         num_warps=num_warps,
         num_stages=1,
     )

lightllm/common/basemodel/triton_kernel/gen_sampling_params.py

Lines changed: 4 additions & 0 deletions
@@ -125,6 +125,7 @@ def _token_id_counter_update_kernel(
     batch_size,
     HAS_MASK: tl.constexpr,
     BLOCK: tl.constexpr,
+    OLD_VERSION_TRITON: tl.constexpr,
 ):
 
     block_start_index = tl.program_id(0) * BLOCK
@@ -136,6 +137,8 @@
 
     if HAS_MASK:
         mask = tl.load(mask_ptr + offs, mask=loc_mask, other=False)
+        if OLD_VERSION_TRITON:
+            mask = mask != 0
     tl.atomic_add(
         req_to_out_token_id_counter_ptr + req_idx * counter_stride_m + token_ids * counter_stride_n,
         1,
@@ -170,6 +173,7 @@ def update_req_to_token_id_counter(
         batch_size=batch_size,
         HAS_MASK=has_mask,
         BLOCK=BLOCK,
+        OLD_VERSION_TRITON=triton.__version__ < "3.2.0",
         num_warps=1,
     )
     return
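The OLD_VERSION_TRITON flag added to the two kernels above (and the `mask_eos != 0` change in the two penalty kernels) normalizes values loaded from boolean tensors before they are used as conditions on Triton releases older than 3.2.0. A minimal sketch of the same version-gating pattern; the kernel and wrapper names here are illustrative, not from the repository:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def _zero_where_masked_kernel(x_ptr, mask_ptr, out_ptr, n, BLOCK: tl.constexpr, OLD_VERSION_TRITON: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    in_range = offs < n
    m = tl.load(mask_ptr + offs, mask=in_range, other=0)
    if OLD_VERSION_TRITON:
        # older Triton hands back the raw integer storage of a bool tensor here,
        # so normalize it before using it as a condition
        m = m != 0
    x = tl.load(x_ptr + offs, mask=in_range, other=0.0)
    out = tl.where(m, 0.0, x)  # zero out the masked positions, keep the rest
    tl.store(out_ptr + offs, out, mask=in_range)


def zero_where_masked(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(x)
    BLOCK = 256
    grid = (triton.cdiv(x.numel(), BLOCK),)
    _zero_where_masked_kernel[grid](
        x, mask, out, x.numel(), BLOCK=BLOCK,
        OLD_VERSION_TRITON=triton.__version__ < "3.2.0",  # same gate as in the diffs above
    )
    return out
```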

lightllm/common/fused_moe/grouped_fused_moe.py

Lines changed: 36 additions & 22 deletions
@@ -332,6 +332,7 @@ grouped_matmul_kernel(
     GROUP_SIZE_M: tl.constexpr,
     MUL_ROUTED_WEIGHT: tl.constexpr = False,
     NEED_K_MASK: tl.constexpr = True,
+    NEED_TRANS: tl.constexpr = False,
 ):
     pid = tl.program_id(0)
 
@@ -367,13 +368,6 @@
             mask=token_mask,
             other=0,
         )
-        if MUL_ROUTED_WEIGHT:
-            a_m_scale = tl.load(
-                expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
-                mask=token_mask,
-                other=0.0,
-            )
-
         offs_bn = (tile_n_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % n
         offs_k = tl.arange(0, BLOCK_SIZE_K)
 
@@ -387,7 +381,7 @@
             b_scale = tl.load(weight_scale_ptr + expert_id, eviction_policy="evict_last")
             ab_scale = a_scale * b_scale
 
-        if use_fp8_w8a8:
+        if NEED_TRANS:
             a_ptrs = token_ptr + (a_m_index // topk_num)[None, :] * token_stride_0 + offs_k[:, None]
             b_ptrs = weights_ptr + weight_stride_0 * expert_id + offs_k[None, :] + offs_bn[:, None] * weight_stride_1
             accumulator = tl.zeros((BLOCK_SIZE_N, BLOCK_SIZE_M), dtype=tl.float32)
@@ -401,16 +395,20 @@
             # tl.multiple_of(a_ptrs, [16, 16])
             # tl.multiple_of(b_ptrs, [16, 16])
 
-            if use_fp8_w8a8:
+            if NEED_TRANS:
                 if NEED_K_MASK:
-                    a = tl.load(a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k), other=0.0)
+                    a = tl.load(
+                        a_ptrs, mask=(token_mask[None, :]) & (offs_k[:, None] < k - step_k * BLOCK_SIZE_K), other=0.0
+                    )
                     b = tl.load(b_ptrs, mask=(offs_k[None, :] < k), other=0.0)
                 else:
                     a = tl.load(a_ptrs, mask=(token_mask[None, :]), other=0.0)
                     b = tl.load(b_ptrs)
             else:
                 if NEED_K_MASK:
-                    a = tl.load(a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k), other=0.0)
+                    a = tl.load(
+                        a_ptrs, mask=(token_mask[:, None]) & (offs_k[None, :] < k - step_k * BLOCK_SIZE_K), other=0.0
+                    )
                     b = tl.load(b_ptrs, mask=(offs_k[:, None] < k), other=0.0)
                 else:
                     a = tl.load(a_ptrs, mask=(token_mask[:, None]), other=0.0)
@@ -421,24 +419,34 @@
                     offs_ks = step_k * BLOCK_SIZE_K // block_size_k
                     a_scale = tl.load(a_scale_ptrs + offs_ks, mask=token_mask, other=0.0)
                     b_scale = tl.load(b_scale_ptrs + offs_ks * weight_scale_stride2)
-                    accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                    if NEED_TRANS:
+                        accumulator += tl.dot(b, a) * b_scale[:, None] * a_scale[None, :]
+                    else:
+                        accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
                 else:
-                    accumulator = tl.dot(b, a, acc=accumulator)
+                    if NEED_TRANS:
+                        accumulator = tl.dot(b, a, acc=accumulator)
+                    else:
+                        accumulator = tl.dot(a, b, acc=accumulator)
             else:
                 accumulator += tl.dot(a, b)
 
             a_ptrs += BLOCK_SIZE_K
             b_ptrs += BLOCK_SIZE_K
-            offs_k += BLOCK_SIZE_K
+
+        if NEED_TRANS:
+            accumulator = accumulator.T
 
         if use_fp8_w8a8:
-            if block_size_k > 0 and block_size_n > 0:
-                accumulator = accumulator.T
-            else:
-                accumulator = accumulator.T
+            if not (block_size_k > 0 and block_size_n > 0):
                accumulator *= ab_scale
 
        if MUL_ROUTED_WEIGHT:
+            a_m_scale = tl.load(
+                expert_to_weights_ptr + expert_id * expert_to_weights_stride0 + offs_am,
+                mask=token_mask,
+                other=0.0,
+            )
            accumulator *= a_m_scale[:, None]
 
        c = accumulator.to(compute_type)
@@ -478,13 +486,15 @@ def _get_grouped_matmul_configs():
            "GROUP_SIZE_M": gm,
            "num_warps": nw,
            "num_stages": ns,
+            "NEED_TRANS": need_trans,
        }
-        for ns in [1, 2, 3, 4, 5]
-        for gm in [1, 2, 4, 8]
-        for nw in [2, 4, 8]
+        for ns in [2, 3, 4, 5]
+        for gm in [1, 16, 32, 64]
+        for nw in [4, 8]
        for bm in [16, 32, 64, 128]
        for bn in [16, 32, 64, 128]
-        for bk in [16, 32, 64, 128]
+        for bk in [32, 64, 128]
+        for need_trans in [True, False]
    ]
 
 
@@ -559,6 +569,9 @@ def grouped_matmul(
     GROUP_SIZE_M = run_config["GROUP_SIZE_M"]
     num_warps = run_config["num_warps"]
     num_stages = run_config["num_stages"]
+    NEED_TRANS = run_config.get("NEED_TRANS", False)
+    if not use_fp8_w8a8:
+        assert NEED_TRANS is False, "only use_fp8_w8a8 mode can use NEED_TRANS to accelerate"
 
     if block_size_k != 0:
         # 如果使用了 block wise 量化,分块大小不能超过 block size
@@ -638,6 +651,7 @@ def grouped_matmul(
         GROUP_SIZE_M=GROUP_SIZE_M,
         MUL_ROUTED_WEIGHT=mul_routed_weight,
         NEED_K_MASK=NEED_K_MASK,
+        NEED_TRANS=NEED_TRANS,
         num_warps=num_warps,
         num_stages=num_stages,
     )
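The new NEED_TRANS config appears to keep the transposed operand layout previously hard-wired to the fp8 path (tl.dot(b, a) on transposed tiles followed by a final accumulator.T), while the default path uses the plain tl.dot(a, b). A minimal sketch of the identity this relies on, with illustrative shapes only:

```python
import torch

# The swapped-and-transposed product equals the plain product: (B^T @ A^T)^T == A @ B.
M, K, N = 16, 32, 8
a = torch.randn(M, K)  # activation tile (BLOCK_SIZE_M x BLOCK_SIZE_K)
b = torch.randn(K, N)  # weight tile (BLOCK_SIZE_K x BLOCK_SIZE_N)

plain = a @ b            # NEED_TRANS=False path: tl.dot(a, b)
swapped = (b.T @ a.T).T  # NEED_TRANS=True path: tl.dot(b, a), then accumulator.T

assert torch.allclose(plain, swapped, atol=1e-5)
```

The assert added in grouped_matmul restricts NEED_TRANS to the use_fp8_w8a8 mode, matching the fp8-only layout described above.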
