Commit f7bf076: merge main
2 parents: a0e1742 + 2d4f7d4
31 files changed: +406 -167 lines


lightllm/common/basemodel/basemodel.py

Lines changed: 1 addition & 1 deletion
@@ -175,7 +175,7 @@ def _init_kv_move_buffer(self):
 
     def _check_mem_size(self):
        self.max_total_token_num = self.mem_manager.size
-        assert self.max_seq_length < self.max_total_token_num
+        assert self.max_seq_length <= self.max_total_token_num
        return
 
    def _init_req_manager(self):
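
The relaxed check admits the boundary case where the longest supported sequence exactly fills the KV pool. A toy illustration with made-up sizes, not taken from any real config:

    max_total_token_num = 8192   # mem_manager.size, total KV-cache token slots
    max_seq_length = 8192        # longest sequence the server must hold at once
    assert max_seq_length <= max_total_token_num   # passes now; the old strict `<` rejected equality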

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_tp.py

Lines changed: 2 additions & 2 deletions
@@ -3,7 +3,7 @@
 import threading
 from typing import Optional, Tuple, List, Dict, Any
 from .base_weight import BaseWeight
-from lightllm.utils.dist_utils import get_global_rank, get_current_device_id
+from lightllm.utils.dist_utils import get_current_rank_in_dp, get_current_device_id
 from lightllm.common.quantization import Quantcfg
 
 
@@ -37,7 +37,7 @@ def __init__(
         self.n_routed_experts = n_routed_experts
         self.split_inter_size = split_inter_size
         self.data_type_ = data_type
-        self.tp_rank_ = get_global_rank()
+        self.tp_rank_ = get_current_rank_in_dp()
         self.experts_up_projs = [None] * self.n_routed_experts
         self.experts_gate_projs = [None] * self.n_routed_experts
         self.experts_up_proj_scales = [None] * self.n_routed_experts
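
With data parallelism, a worker's global rank no longer equals its tensor-parallel shard index, so the weight class now reads its rank from inside the current DP group. A minimal sketch of how that rank would select an expert shard along the intermediate dimension (the tensor and function names below are illustrative assumptions, not lightllm's loader code):

    import torch
    from lightllm.utils.dist_utils import get_current_rank_in_dp

    def slice_expert_up_proj(full_up_proj: torch.Tensor, split_inter_size: int) -> torch.Tensor:
        # Pick this worker's slice of the intermediate dimension. Using get_global_rank()
        # here would compute the wrong offset whenever the DP size is greater than 1.
        tp_rank = get_current_rank_in_dp()
        return full_up_proj[split_inter_size * tp_rank : split_inter_size * (tp_rank + 1), :]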

lightllm/common/basemodel/triton_kernel/destindex_copy_kv.py

Lines changed: 54 additions & 22 deletions
@@ -6,19 +6,24 @@
 
 @triton.jit
 def _fwd_kernel_destindex_copy_kv(
-    K, Dest_loc,
+    K,
+    Dest_loc,
     Out,
-    stride_k_bs, stride_k_h, stride_k_d,
-    stride_o_bs, stride_o_h, stride_o_d,
+    stride_k_bs,
+    stride_k_h,
+    stride_k_d,
+    stride_o_bs,
+    stride_o_h,
+    stride_o_d,
     head_num,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_HEAD: tl.constexpr
+    BLOCK_HEAD: tl.constexpr,
 ):
     cur_index = tl.program_id(0)
     offs_h = tl.arange(0, BLOCK_HEAD)
     offs_d = tl.arange(0, BLOCK_DMODEL)
 
-    dest_index = tl.load(Dest_loc + cur_index)
+    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
 
     k_ptrs = K + cur_index * stride_k_bs + stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :]
     o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
@@ -39,9 +44,15 @@ def destindex_copy_kv(K, DestLoc, Out):
     num_warps = 1
 
     _fwd_kernel_destindex_copy_kv[grid](
-        K, DestLoc, Out,
-        K.stride(0), K.stride(1), K.stride(2),
-        Out.stride(0), Out.stride(1), Out.stride(2),
+        K,
+        DestLoc,
+        Out,
+        K.stride(0),
+        K.stride(1),
+        K.stride(2),
+        Out.stride(0),
+        Out.stride(1),
+        Out.stride(2),
         head_num,
         BLOCK_DMODEL=head_dim,
         BLOCK_HEAD=BLOCK_HEAD,
@@ -53,23 +64,35 @@ def destindex_copy_kv(K, DestLoc, Out):
 
 @triton.jit
 def _fwd_kernel_destindex_copy_quantize_kv(
-    K, Dest_loc, Out, Out_scale,
-    stride_k_bs, stride_k_h, stride_k_d,
-    stride_o_bs, stride_o_h, stride_o_d,
-    stride_os_bs, stride_os_h, stride_os_d,
+    K,
+    Dest_loc,
+    Out,
+    Out_scale,
+    stride_k_bs,
+    stride_k_h,
+    stride_k_d,
+    stride_o_bs,
+    stride_o_h,
+    stride_o_d,
+    stride_os_bs,
+    stride_os_h,
+    stride_os_d,
     head_num,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_HEAD: tl.constexpr
+    BLOCK_HEAD: tl.constexpr,
 ):
     cur_index = tl.program_id(0)
     offs_h = tl.arange(0, BLOCK_HEAD)
     offs_d = tl.arange(0, BLOCK_DMODEL)
 
-    dest_index = tl.load(Dest_loc + cur_index)
-    src_data = tl.load(K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
-                       mask=offs_h[:, None] < head_num, other=0.0)
+    dest_index = tl.load(Dest_loc + cur_index).to(tl.int64)
+    src_data = tl.load(
+        K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :],
+        mask=offs_h[:, None] < head_num,
+        other=0.0,
+    )
     abs_data = tl.abs(src_data)
-    data_scale = (tl.max(abs_data, axis=1) / 127.).to(Out_scale.dtype.element_ty)[:, None]
+    data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None]
     q_src_data = (src_data / data_scale).to(tl.int8)
     o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :]
     os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None]
@@ -88,10 +111,19 @@ def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):
     num_warps = 1
 
     _fwd_kernel_destindex_copy_quantize_kv[grid](
-        K, DestLoc, Out, Out_scale,
-        K.stride(0), K.stride(1), K.stride(2),
-        Out.stride(0), Out.stride(1), Out.stride(2),
-        Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),
+        K,
+        DestLoc,
+        Out,
+        Out_scale,
+        K.stride(0),
+        K.stride(1),
+        K.stride(2),
+        Out.stride(0),
+        Out.stride(1),
+        Out.stride(2),
+        Out_scale.stride(0),
+        Out_scale.stride(1),
+        Out_scale.stride(2),
         head_num,
         BLOCK_DMODEL=head_dim,
         BLOCK_HEAD=BLOCK_HEAD,
@@ -149,6 +181,6 @@ def test2():
     print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32)))
 
 
-if __name__ == '__main__':
+if __name__ == "__main__":
     test1()
     test2()
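
Beyond the formatting-only changes, the functional change in this file is the `.to(tl.int64)` cast on the loaded destination index: for a large KV buffer, `dest_index * stride_o_bs` can exceed the int32 range, so the index is widened before the pointer arithmetic. A minimal usage sketch, assuming a CUDA device and the [token_num, head_num, head_dim] layout implied by the strides:

    import torch
    from lightllm.common.basemodel.triton_kernel.destindex_copy_kv import destindex_copy_kv

    token_num, head_num, head_dim = 4, 8, 128
    K = torch.randn(token_num, head_num, head_dim, dtype=torch.float16, device="cuda")
    DestLoc = torch.tensor([3, 7, 9, 15], dtype=torch.int32, device="cuda")  # target slots in the KV buffer
    Out = torch.zeros(16, head_num, head_dim, dtype=torch.float16, device="cuda")

    destindex_copy_kv(K, DestLoc, Out)          # row i of K is copied into Out[DestLoc[i]]
    assert torch.equal(Out[DestLoc.long()], K)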

lightllm/distributed/communication_op.py

Lines changed: 6 additions & 9 deletions
@@ -27,12 +27,11 @@
 from lightllm.utils.device_utils import has_nvlink
 from lightllm.utils.envs_utils import get_env_start_args, get_deepep_num_max_dispatch_tokens_per_rank
 from lightllm.utils.dist_utils import (
-    get_current_device_id,
-    get_node_world_size,
     get_global_world_size,
     get_dp_world_size,
     get_global_rank,
     get_current_rank_in_dp,
+    create_new_group_for_current_dp,
 )
 from lightllm.utils.device_utils import get_device_sm_count
 from lightllm.utils.sgl_utils import HAS_SGL_KERNEL
@@ -63,17 +62,15 @@ def __init__(self):
         self.custom_reduce = None
         self.custom_gather = None
         self.dp_world_size = get_dp_world_size()
-        ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
-        self.device_group = dist.new_group(ranks, backend="nccl")
+        self.device_group = create_new_group_for_current_dp("nccl")
 
     def init_custom_reduce(self) -> None:
         if not HAS_SGL_KERNEL or not has_nvlink() or self.dp_world_size not in [2, 4, 6, 8]:
             return
         args = get_env_start_args()
         if args.disable_custom_allreduce:
             return
-        ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
-        cpu_group = dist.new_group(ranks, backend="gloo")
+        cpu_group = create_new_group_for_current_dp("gloo")
         self.custom_reduce = CustomAllreduce(cpu_group, torch.cuda.current_device())
         logger.info("Enable Custom ALLReduce. You can disable it by settting --disable_custom_allreduce.")
 
@@ -82,10 +79,10 @@ def init_custom_gather(self) -> None:
             return
 
         args = get_env_start_args()
-        if args.disable_custom_allgather:
+        if not args.enable_custom_allgather:
             return
-        ranks = list([get_global_rank() - get_current_rank_in_dp() + i for i in range(self.dp_world_size)])
-        cpu_group = dist.new_group(ranks, backend="gloo")
+
+        cpu_group = create_new_group_for_current_dp("gloo")
         self.custom_gather = CustomAllgather(cpu_group, torch.cuda.current_device())
         logger.info("Enable Custom ALLGather. You can disable it by settting --disable_custom_allgather")
 
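
The inline rank-list construction removed in three places suggests what the new create_new_group_for_current_dp helper does. A sketch of roughly equivalent logic, assuming it lives in lightllm.utils.dist_utils next to the other rank helpers; the real implementation may cache the group or differ in details:

    import torch.distributed as dist
    from lightllm.utils.dist_utils import get_global_rank, get_current_rank_in_dp, get_dp_world_size

    def create_new_group_for_current_dp(backend: str):
        # Rebuild the contiguous global-rank list covering this data-parallel replica,
        # as the removed inline code did, then create a process group on that backend.
        first_rank = get_global_rank() - get_current_rank_in_dp()
        ranks = [first_rank + i for i in range(get_dp_world_size())]
        return dist.new_group(ranks, backend=backend)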

lightllm/distributed/custom_all_gather.py

Lines changed: 5 additions & 2 deletions
@@ -31,8 +31,11 @@
 from lightllm.common.basemodel.layer_infer.cache_tensor_manager import g_cache_manager
 
 
-if light_ops is not None:
-    light_ops.meta_size()
+try:
+    if light_ops is not None:
+        light_ops.meta_size()
+except:
+    pass
 
 logger = init_logger(__name__)
 

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad.py

Lines changed: 6 additions & 8 deletions
@@ -4,9 +4,7 @@
 import triton.language as tl
 import math
 import torch.nn.functional as F
-
-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+from lightllm.utils.device_utils import is_tesla
 
 
 @triton.jit
@@ -165,7 +163,7 @@ def context_attention_fwd(
     softmax_scale,
 ):
 
-    BLOCK = 128 if not TESLA else 64
+    BLOCK = 128 if not is_tesla() else 64
     q_nope_dim = q_nope.shape[-1]
     q_rope_dim = q_rope.shape[-1]
     assert q_nope_dim == kv_nope.shape[-1]
@@ -174,9 +172,9 @@ def context_attention_fwd(
     assert q_rope_dim in {16, 32, 64, 128, 256}
 
     if q_nope_dim >= 512:
-        BLOCK = 32 if TESLA or CUDA_CAPABILITY[0] >= 9 else 64
+        BLOCK = 32 if is_tesla() or torch.cuda.get_device_capability()[0] >= 9 else 64
     else:
-        BLOCK = 128 if not TESLA else 64
+        BLOCK = 128 if not is_tesla() else 64
 
     if q_nope.dtype == torch.float32:
         BLOCK = BLOCK // 4
@@ -370,9 +368,9 @@ def context_attention_fwd_no_prompt_cache(
     assert q_rope_dim in {16, 32, 64, 128, 256}
 
     if q_nope_dim >= 512:
-        BLOCK = 32 if TESLA or CUDA_CAPABILITY[0] >= 9 else 64
+        BLOCK = 32 if is_tesla() or torch.cuda.get_device_capability()[0] >= 9 else 64
     else:
-        BLOCK = 128 if not TESLA else 64
+        BLOCK = 128 if not is_tesla() else 64
 
     if q_nope.dtype == torch.float32:
         BLOCK = BLOCK // 4
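
These kernels previously probed the GPU at import time through module-level TESLA and CUDA_CAPABILITY constants; they now call lightllm.utils.device_utils.is_tesla() instead. A plausible sketch of such a helper, mirroring the constant it replaces; the actual helper may cache differently or pick the device another way:

    from functools import lru_cache
    import torch

    @lru_cache(maxsize=None)
    def is_tesla() -> bool:
        # Same check the removed module-level constant performed, evaluated lazily so
        # importing the kernel module no longer needs an initialized CUDA context.
        return "Tesla" in torch.cuda.get_device_name(0)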

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_fp8.py

Lines changed: 4 additions & 6 deletions
@@ -4,9 +4,7 @@
 import triton.language as tl
 import math
 import torch.nn.functional as F
-
-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+from lightllm.utils.device_utils import is_tesla
 
 
 @triton.jit
@@ -176,7 +174,7 @@ def context_attention_fwd_fp8(
     softmax_scale,
 ):
 
-    BLOCK = 128 if not TESLA else 64
+    BLOCK = 128 if not is_tesla() else 64
     q_nope_dim = q_nope.shape[-1]
     q_rope_dim = q_rope.shape[-1]
     assert q_nope_dim == kv_nope.shape[-1]
@@ -185,9 +183,9 @@ def context_attention_fwd_fp8(
     assert q_rope_dim in {16, 32, 64, 128, 256}
 
     if q_nope_dim >= 512:
-        BLOCK = 32 if TESLA or CUDA_CAPABILITY[0] >= 9 else 64
+        BLOCK = 32 if is_tesla() or torch.cuda.get_device_capability()[0] >= 9 else 64
     else:
-        BLOCK = 128 if not TESLA else 64
+        BLOCK = 128 if not is_tesla() else 64
 
     if q_nope.dtype == torch.float32:
         BLOCK = BLOCK // 4

lightllm/models/deepseek2/triton_kernel/context_flashattention_nopad_with_v.py

Lines changed: 4 additions & 6 deletions
@@ -4,9 +4,7 @@
 import triton.language as tl
 import math
 import torch.nn.functional as F
-
-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+from lightllm.utils.device_utils import is_tesla
 
 
 @triton.jit
@@ -148,7 +146,7 @@ def context_attention_fwd_with_v(
     softmax_scale,
 ):
 
-    BLOCK = 128 if not TESLA else 64
+    BLOCK = 128 if not is_tesla() else 64
     q_nope_dim = q_nope.shape[-1]
     q_rope_dim = q_rope.shape[-1]
     assert q_nope_dim == k_nope.shape[-1]
@@ -158,9 +156,9 @@ def context_attention_fwd_with_v(
     assert q_nope_dim == v.shape[-1]
 
     if q_nope_dim >= 512:
-        BLOCK = 64 if not TESLA else 32
+        BLOCK = 64 if not is_tesla() else 32
     else:
-        BLOCK = 128 if not TESLA else 64
+        BLOCK = 128 if not is_tesla() else 64
 
     if q_nope.dtype == torch.float32:
         BLOCK = BLOCK // 4

lightllm/models/deepseek2/triton_kernel/sample_kv.py

Lines changed: 4 additions & 5 deletions
@@ -3,8 +3,7 @@
 import triton
 import triton.language as tl
 
-TESLA = "Tesla" in torch.cuda.get_device_name(0)
-CUDA_CAPABILITY = torch.cuda.get_device_capability()
+from lightllm.utils.device_utils import is_tesla
 
 
 @triton.jit
@@ -77,14 +76,14 @@ def sample_kv(
     kv_scale=None,
     k_scale=None,
 ):
-    BLOCK = 128 if not TESLA else 64
+    BLOCK = 128 if not is_tesla() else 64
 
     nope_dim = kv_nope.shape[-1]
     rope_dim = kv_rope.shape[-1]
     if nope_dim >= 512:
-        BLOCK = 64 if not TESLA else 32
+        BLOCK = 64 if not is_tesla() else 32
     else:
-        BLOCK = 128 if not TESLA else 64
+        BLOCK = 128 if not is_tesla() else 64
 
     batch = b_seq_len.shape[0]
 
lightllm/models/internvl/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def __init__(self, tokenizer, model_cfg, **kwargs):
4848
def init_imageitem_extral_params(
4949
self, img: ImageItem, multi_params: MultimodalParams, sampling_params: SamplingParams
5050
):
51-
if sampling_params.image_max_patch_num >= 0:
51+
if sampling_params.image_max_patch_num > 0:
5252
img.extra_params["image_patch_max_num"] = sampling_params.image_max_patch_num
5353
return
5454
elif os.getenv("MAX_PATCH_NUM"):
