Skip to content

Commit 5ca9695

Browse files
authored
fix
1 parent 9f9a7d3 commit 5ca9695

File tree

4 files changed

+8
-15
lines changed

4 files changed

+8
-15
lines changed

lightllm/common/basemodel/layer_infer/cache_tensor_manager.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
from torch.storage import UntypedStorage
99
from dataclasses import field
1010
from lightllm.utils.log_utils import init_logger
11-
from lightllm.common.basemodel.triton_kernel.add_in_place import add_in_place
1211

1312
logger = init_logger(__name__)
1413

@@ -29,9 +28,6 @@ def custom_del(self: torch.Tensor):
2928
g_cache_manager.changed_ptr.add(storage_weak_ptr)
3029
return
3130

32-
def custom_add_(self, other, *, alpha=1):
33-
return add_in_place(self, other, alpha=alpha)
34-
3531
@dataclasses.dataclass
3632
class BufNode:
3733
inner_tensor: torch.Tensor

lightllm/common/basemodel/layer_weights/meta_weights/fused_moe_weight_ep.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -235,24 +235,25 @@ def prefilled_group_gemm(
235235
recv_topk_weights: torch.Tensor,
236236
hidden_dtype=torch.bfloat16,
237237
):
238+
device = recv_x[0].device
238239
w1, w1_scale = self.w1
239240
w2, w2_scale = self.w2
240241
_, K = recv_x[0].shape
241242
_, N, _ = w1.shape
242243
# scatter
243244
all_tokens = sum(num_recv_tokens_per_expert_list)  # calculate the total (padded) number of tokens.
244245
# gather_out shape [receive_num_tokens, hidden]
245-
gather_out = torch.empty_like(recv_x[0], device=recv_x[0].device, dtype=hidden_dtype)
246+
gather_out = torch.empty_like(recv_x[0], device=device, dtype=hidden_dtype)
246247
if all_tokens > 0:
247248
input_tensor = [
248-
torch.empty((all_tokens, K), device=recv_x[0].device, dtype=recv_x[0].dtype),
249-
torch.empty((all_tokens, K // 128), device=recv_x[0].device, dtype=torch.float32),
249+
torch.empty((all_tokens, K), device=device, dtype=recv_x[0].dtype),
250+
torch.empty((all_tokens, K // 128), device=device, dtype=torch.float32),
250251
]
251252
# when m_indices is filled ok.
252253
# m_indices shows which expert each token uses, e.g. [0, 0, 0, 0, .... 1, 1, 1, 1,...., cur_expert_num - 1, ..]
253254
# the count of 0 is num_recv_tokens_per_expert_list[0], the count of 1 is num_recv_tokens_per_expert_list[1]
254255
# ...
255-
m_indices = torch.empty(all_tokens, device=recv_x[0].device, dtype=torch.int32)
256+
m_indices = torch.empty(all_tokens, device=device, dtype=torch.int32)
256257
# output_index shape [receive_num_tokens, topk_num]
257258
# output_index use to show the token index in input_tensor
258259
output_index = torch.empty_like(recv_topk_idx)
@@ -276,19 +277,19 @@ def prefilled_group_gemm(
276277
)
277278
input_tensor[1] = tma_align_input_scale(input_tensor[1])
278279
# groupgemm (contiguous layout)
279-
gemm_out_a = torch.empty((all_tokens, N), device=recv_x[0].device, dtype=hidden_dtype)
280+
gemm_out_a = torch.empty((all_tokens, N), device=device, dtype=hidden_dtype)
280281

281282
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(input_tensor, (w1, w1_scale), gemm_out_a, m_indices)
282283

283284
# silu_and_mul_fwd + quant
284285
# TODO fused kernel
285-
silu_out = torch.empty((all_tokens, N // 2), device=recv_x[0].device, dtype=hidden_dtype)
286+
silu_out = torch.empty((all_tokens, N // 2), device=device, dtype=hidden_dtype)
286287

287288
silu_and_mul_fwd(gemm_out_a.view(-1, N), silu_out)
288289
qsilu_out, qsilu_out_scale = tma_aligned_quantize(silu_out)
289290

290291
# groupgemm (contiguous layout)
291-
gemm_out_b = torch.empty((all_tokens, K), device=recv_x[0].device, dtype=hidden_dtype)
292+
gemm_out_b = torch.empty((all_tokens, K), device=device, dtype=hidden_dtype)
292293

293294
deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
294295
(qsilu_out, qsilu_out_scale), (w2, w2_scale), gemm_out_b, m_indices

lightllm/common/fused_moe/grouped_fused_moe_ep.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from deep_ep import Buffer, EventOverlap
2323
import deep_gemm
2424

25-
# Set the number of SMs to use
2625
except:
2726
logger.warning("no deepep or deep_gemm")
2827

lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,6 @@
3131
from lightllm.utils.dist_utils import get_global_world_size
3232

3333

34-
# from lightllm.utils.custom_kernel_utis import torch_cat_3
35-
36-
3734
class Deepseek2TransformerLayerInfer(LlamaTransformerLayerInfer):
3835
def __init__(self, layer_num, network_config, mode=[]):
3936
self.tp_k_head_num_ = 1

0 commit comments

Comments
 (0)