Skip to content

Commit fe5bda0

Browse files
author
none
committed
fix
1 parent 9ae19d8 commit fe5bda0

File tree

1 file changed

+44
-106
lines changed

1 file changed

+44
-106
lines changed

lightllm/common/fused_moe/moe_silu_and_mul.py

Lines changed: 44 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,8 @@
44
import triton.language as tl
55
from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
66

7-
87
@triton.jit
def _silu_and_mul_kernel_fast(
    input_ptr,
    output_ptr,
    stride_input_m,
    stride_input_n,
    stride_output_m,
    stride_output_n,
    size_m,
    size_n,
    BLOCK_M: tl.constexpr,
    BLOCK_N: tl.constexpr,
    NEED_MASK: tl.constexpr,
):
    """SiLU-and-mul (SwiGLU activation) kernel.

    Reads a [size_m, 2 * size_n] input laid out as [gate | up] along the
    last dimension and writes silu(gate) * up into a [size_m, size_n]
    output.  Grid axis 0 tiles rows in chunks of BLOCK_M (each program
    loops over its rows one at a time), grid axis 1 tiles columns in
    chunks of BLOCK_N.  NEED_MASK is a compile-time flag: it is False
    when size_n is a multiple of BLOCK_N, which lets Triton drop the
    column bounds check entirely.
    """
    # Promote row strides to int64 so row_index * stride cannot overflow
    # int32 for large tensors.
    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)

    m_block_index = tl.program_id(0)
    n_block_index = tl.program_id(1)
    n_offsets = n_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
    m_start_index = m_block_index * BLOCK_M
    m_end_index = (m_block_index + 1) * BLOCK_M
    # Clamp the last row block so the per-row loop never walks past size_m.
    m_end_index = tl.where(m_end_index < size_m, m_end_index, size_m)
    if NEED_MASK:
        mask = n_offsets[None, :] < size_n
    else:
        # No mask at all: tl.load/tl.store accept mask=None and skip the
        # bounds check, which is faster on the aligned fast path.
        mask = None

    for m_index in range(m_start_index, m_end_index):
        # Input row holds [gate (size_n) | up (size_n)] side by side.
        gate_offsets = m_index * stride_input_m + n_offsets[None, :]
        up_offsets = m_index * stride_input_m + (n_offsets[None, :] + size_n)
        out_offsets = m_index * stride_output_m + n_offsets[None, :]

        up = tl.load(
            input_ptr + up_offsets,
            mask=mask,
            other=0.0,
        )
        # Compute the sigmoid in fp32 for accuracy regardless of the
        # storage dtype.
        gate = tl.load(
            input_ptr + gate_offsets,
            mask=mask,
            other=0.0,
        ).to(tl.float32)

        # silu(x) = x * sigmoid(x) = x / (1 + exp(-x))
        gate = gate / (1 + tl.exp(-gate))
        gate = gate.to(input_ptr.dtype.element_ty)

        tl.store(
            output_ptr + out_offsets,
            up * gate,
            mask=mask,
        )
10359

10460

10561
def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
@@ -116,26 +72,6 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
11672
if not run_config:
11773
run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype))
11874

119-
if size_m <= 4096:
120-
BLOCK_N = run_config["BLOCK_N"]
121-
grid = (
122-
size_m,
123-
triton.cdiv(size_n, BLOCK_N),
124-
)
125-
NEED_MASK = size_n % BLOCK_N != 0
126-
_silu_and_mul_kernel_fast[grid](
127-
input,
128-
output,
129-
stride_input_m,
130-
stride_input_n,
131-
stride_output_m,
132-
stride_output_n,
133-
size_n,
134-
BLOCK_N=BLOCK_N,
135-
NEED_MASK=NEED_MASK,
136-
)
137-
return
138-
13975
BLOCK_M = run_config["BLOCK_M"]
14076
BLOCK_N = run_config["BLOCK_N"]
14177
num_warps = run_config["num_warps"]
@@ -144,17 +80,19 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
14480
triton.cdiv(size_m, BLOCK_M),
14581
triton.cdiv(size_n, BLOCK_N),
14682
)
147-
_silu_and_mul_kernel[grid](
148-
input,
149-
output,
150-
stride_input_m,
151-
stride_input_n,
152-
stride_output_m,
153-
stride_output_n,
154-
size_m,
155-
size_n,
83+
NEED_MASK = (size_n % BLOCK_N) != 0
84+
_silu_and_mul_kernel_fast[grid](
85+
input_ptr=input,
86+
output_ptr=output,
87+
stride_input_m=stride_input_m,
88+
stride_input_n=stride_input_n,
89+
stride_output_m=stride_output_m,
90+
stride_output_n=stride_output_n,
91+
size_m=size_m,
92+
size_n=size_n,
15693
BLOCK_M=BLOCK_M,
15794
BLOCK_N=BLOCK_N,
95+
NEED_MASK=NEED_MASK,
15896
num_warps=num_warps,
15997
)
16098
return

0 commit comments

Comments (0)