 import triton.language as tl
 from .moe_silu_and_mul_config import MoeSiluAndMulKernelConfig
 
-
 @triton.jit
-def _silu_and_mul_kernel(
+def _silu_and_mul_kernel_fast(
     input_ptr,
     output_ptr,
     stride_input_m,
@@ -17,89 +16,46 @@ def _silu_and_mul_kernel(
     size_n,
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
-):
-    stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
-    stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
-
-    tid = tl.program_id(0)
-    input_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
-    output_m_offsets = tid * BLOCK_M + tl.arange(0, BLOCK_M)
-
-    pid = tl.program_id(1)
-    input_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
-    output_n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    up_offsets = input_m_offsets[:, None] * stride_input_m + (input_n_offsets[None, :] + size_n)
-    gate_offsets = input_m_offsets[:, None] * stride_input_m + input_n_offsets[None, :]
-    res_offsets = output_m_offsets[:, None] * stride_output_m + output_n_offsets[None, :]
-
-    up = tl.load(
-        input_ptr + up_offsets,
-        mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],
-        other=0.0,
-    )
-    gate = tl.load(
-        input_ptr + gate_offsets,
-        mask=(input_n_offsets < size_n)[None, :] * (input_m_offsets < size_m)[:, None],
-        other=0.0,
-    ).to(tl.float32)
-
-    gate = gate / (1 + tl.exp(-gate))
-    gate = gate.to(input_ptr.dtype.element_ty)
-
-    tl.store(
-        output_ptr + res_offsets,
-        up * gate,
-        mask=(output_n_offsets < size_n)[None, :] * (output_m_offsets < size_m)[:, None],
-    )
-
-
-@triton.jit
-def _silu_and_mul_kernel_fast(
-    input_ptr,
-    output_ptr,
-    stride_input_m,
-    stride_input_n,
-    stride_output_m,
-    stride_output_n,
-    size_n,
-    BLOCK_N: tl.constexpr,
     NEED_MASK: tl.constexpr,
 ):
     stride_input_m = tl.cast(stride_input_m, dtype=tl.int64)
     stride_output_m = tl.cast(stride_output_m, dtype=tl.int64)
 
-    cur_batch = tl.program_id(0)
-    pid = tl.program_id(1)
-    n_offsets = pid * BLOCK_N + tl.arange(0, BLOCK_N)
-
-    up_offsets = cur_batch * stride_input_m + (n_offsets[None, :] + size_n)
-    gate_offsets = cur_batch * stride_input_m + n_offsets[None, :]
-    res_offsets = cur_batch * stride_output_m + n_offsets[None, :]
+    m_block_index = tl.program_id(0)
+    n_block_index = tl.program_id(1)
+    n_offsets = n_block_index * BLOCK_N + tl.arange(0, BLOCK_N)
+    m_start_index = m_block_index * BLOCK_M
+    m_end_index = (m_block_index + 1) * BLOCK_M
+    m_end_index = tl.where(m_end_index < size_m, m_end_index, size_m)
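+    # This program covers rows [m_start_index, m_end_index); clamping to size_m
+    # keeps the last row block from reading or writing out of bounds.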
     if NEED_MASK:
         mask = n_offsets[None, :] < size_n
     else:
-        mask = True
-
-    up = tl.load(
-        input_ptr + up_offsets,
-        mask=mask,
-        other=0.0,
-    )
-    gate = tl.load(
-        input_ptr + gate_offsets,
-        mask=mask,
-        other=0.0,
-    ).to(tl.float32)
-
-    gate = gate / (1 + tl.exp(-gate))
-    gate = gate.to(input_ptr.dtype.element_ty)
-
-    tl.store(
-        output_ptr + res_offsets,
-        up * gate,
-        mask=mask,
-    )
+        mask = None
+
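+    # Row loop: each iteration handles one row of this program's BLOCK_M rows.
+    # The gate projection sits in the first size_n columns of the input row,
+    # the up projection in the following size_n columns.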
+    for m_index in range(m_start_index, m_end_index):
+        gate_offsets = m_index * stride_input_m + n_offsets[None, :]
+        up_offsets = m_index * stride_input_m + (n_offsets[None, :] + size_n)
+        out_offsets = m_index * stride_output_m + n_offsets[None, :]
+
+        up = tl.load(
+            input_ptr + up_offsets,
+            mask=mask,
+            other=0.0,
+        )
+        gate = tl.load(
+            input_ptr + gate_offsets,
+            mask=mask,
+            other=0.0,
+        ).to(tl.float32)
+
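+        # SiLU(gate) = gate * sigmoid(gate) = gate / (1 + exp(-gate)),
+        # computed in float32 and cast back to the input dtype before the multiply.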
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+
+        tl.store(
+            output_ptr + out_offsets,
+            up * gate,
+            mask=mask,
+        )
 
 
 def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
@@ -116,26 +72,6 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
     if not run_config:
         run_config = MoeSiluAndMulKernelConfig.try_to_get_best_config(M=size_m, N=size_n, out_dtype=str(output.dtype))
 
-    if size_m <= 4096:
-        BLOCK_N = run_config["BLOCK_N"]
-        grid = (
-            size_m,
-            triton.cdiv(size_n, BLOCK_N),
-        )
-        NEED_MASK = size_n % BLOCK_N != 0
-        _silu_and_mul_kernel_fast[grid](
-            input,
-            output,
-            stride_input_m,
-            stride_input_n,
-            stride_output_m,
-            stride_output_n,
-            size_n,
-            BLOCK_N=BLOCK_N,
-            NEED_MASK=NEED_MASK,
-        )
-        return
-
     BLOCK_M = run_config["BLOCK_M"]
     BLOCK_N = run_config["BLOCK_N"]
     num_warps = run_config["num_warps"]
@@ -144,17 +80,19 @@ def silu_and_mul_fwd(input: torch.Tensor, output: torch.Tensor, **run_config):
         triton.cdiv(size_m, BLOCK_M),
         triton.cdiv(size_n, BLOCK_N),
     )
-    _silu_and_mul_kernel[grid](
-        input,
-        output,
-        stride_input_m,
-        stride_input_n,
-        stride_output_m,
-        stride_output_n,
-        size_m,
-        size_n,
+    NEED_MASK = (size_n % BLOCK_N) != 0
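+    # Column masking is only needed when size_n is not a multiple of BLOCK_N;
+    # otherwise the kernel skips the mask entirely (mask = None).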
+    _silu_and_mul_kernel_fast[grid](
+        input_ptr=input,
+        output_ptr=output,
+        stride_input_m=stride_input_m,
+        stride_input_n=stride_input_n,
+        stride_output_m=stride_output_m,
+        stride_output_n=stride_output_n,
+        size_m=size_m,
+        size_n=size_n,
         BLOCK_M=BLOCK_M,
         BLOCK_N=BLOCK_N,
+        NEED_MASK=NEED_MASK,
         num_warps=num_warps,
     )
     return
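
For reference, the fused kernel above is numerically equivalent to the following eager PyTorch computation. This is a minimal sketch, not part of the commit; it assumes the input is laid out as [M, 2 * N] with the gate projection in the first N columns and the up projection in the last N, which is what the offset math in the kernel implies.

    import torch
    import torch.nn.functional as F

    def silu_and_mul_ref(input: torch.Tensor) -> torch.Tensor:
        # Split the last dimension into gate (first half) and up (second half).
        gate, up = input.chunk(2, dim=-1)
        # Apply SiLU to gate in float32, cast back, then multiply by up,
        # mirroring the dtype handling inside the Triton kernel.
        return F.silu(gate.float()).to(input.dtype) * up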