from fast_llm.functional.triton.sparse_linear import output_sparse_matmul
from fast_llm.tensor import param_get_and_unset_is_zero

-# Triton requires global variables to be annotated with `constexpr`.
-_TritonActivationType: tl_constexpr = ActivationType
-

@triton_jit()
def triton_mlp_activation_forward_kernel(
@@ -50,18 +47,19 @@ def triton_mlp_activation_forward_kernel(

    input_ = tl.load(input_ptr, mask=mask).to(tl.float32)

-    if activation_type == _TritonActivationType.gelu:
+    # Triton doesn't like enums, so we use str instead of ActivationType.
+    if activation_type == "gelu":
        tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
        tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
        out = input_ * 0.5 * (1.0 + tanh)
-    elif activation_type == _TritonActivationType.silu:
+    elif activation_type == "silu":
        out = input_ / (1 + tl.exp(-input_))
-    elif activation_type == _TritonActivationType.relu:
+    elif activation_type == "relu":
        out = tl.where(input_ > 0, input_, 0)
-    elif activation_type == _TritonActivationType.squared_relu:
+    elif activation_type == "squared_relu":
        relu_out = tl.where(input_ > 0, input_, 0)
        out = relu_out * relu_out
-    elif activation_type == _TritonActivationType.identity:
+    elif activation_type == "identity":
        out = input_
    else:
        tl.static_assert(False, activation_type)
@@ -100,28 +98,29 @@ def triton_mlp_activation_backward_kernel(
    input_ = tl.load(input_ptr, mask=mask).to(tl.float32)
    output_grad = tl.load(grad_output_ptr + output_offsets, mask=mask).to(tl.float32)

-    if activation_type == _TritonActivationType.gelu:
+    # Triton doesn't like enums, so we use str instead of ActivationType.
+    if activation_type == "gelu":
        tanh_input = 0.79788456 * input_ * (1 + 0.044715 * input_ * input_)
        tanh = 1 - 2 / (1 + tl.exp(2 * tanh_input))
        grad = 0.5 * input_ * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * input_ * input_)) + 0.5 * (1 + tanh)
        if gated or recompute:
            out = input_ * 0.5 * (1.0 + tanh)
-    elif activation_type == _TritonActivationType.silu:
+    elif activation_type == "silu":
        exp = tl.exp(-input_)
        sigma = 1 / (1 + exp)
        grad = sigma * sigma + (1 + input_) / (2 + exp + 1 / exp)
        if gated or recompute:
            out = input_ * sigma
-    elif activation_type == _TritonActivationType.relu:
+    elif activation_type == "relu":
        grad = tl.where(input_ > 0, 1, 0)
        if gated or recompute:
            out = tl.where(input_ > 0, input_, 0)
-    elif activation_type == _TritonActivationType.squared_relu:
+    elif activation_type == "squared_relu":
        relu_out = tl.where(input_ > 0, input_, 0)
        grad = 2 * relu_out
        if gated or recompute:
            out = relu_out * relu_out
-    elif activation_type == _TritonActivationType.identity:
+    elif activation_type == "identity":
        grad = 1
        if gated or recompute:
            out = input_
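Context for the "+" comment in both hunks: Triton cannot specialize a kernel on a Python enum member, but it can specialize on a plain string marked tl.constexpr, so the kernels now compare activation_type against string literals and dead branches are dropped at compile time. The sketch below illustrates that pattern in isolation; it is not taken from the repository, the kernel name, tensor names, and block size are illustrative, and the launch only runs if a CUDA device is present.

import torch
import triton
import triton.language as tl


@triton.jit
def _activation_kernel(x_ptr, out_ptr, n, activation: tl.constexpr, block: tl.constexpr):
    # One program handles one block of `block` contiguous elements.
    offsets = tl.program_id(0) * block + tl.arange(0, block)
    mask = offsets < n
    x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
    # `activation` is a constexpr string: the comparison is resolved at compile
    # time and only the selected branch is compiled into the kernel.
    if activation == "relu":
        y = tl.where(x > 0, x, 0)
    elif activation == "silu":
        y = x / (1 + tl.exp(-x))
    else:
        tl.static_assert(False, activation)
    tl.store(out_ptr + offsets, y, mask=mask)


if torch.cuda.is_available():
    x = torch.randn(4096, device="cuda")
    out = torch.empty_like(x)
    # The string argument selects the branch; each distinct value produces its
    # own specialization of the kernel.
    _activation_kernel[(triton.cdiv(x.numel(), 1024),)](x, out, x.numel(), "silu", 1024)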
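The closed-form gradients in the backward hunk can be sanity-checked against autograd. The snippet below is a standalone check written for this review, not part of the commit; it assumes PyTorch is available, the tensor size is arbitrary, and float64 keeps the comparison tight.

import torch

x = torch.randn(4096, dtype=torch.float64, requires_grad=True)

# SiLU: autograd vs. the kernel's sigma * sigma + (1 + x) / (2 + exp + 1 / exp) form,
# which is algebraically sigma + x * sigma * (1 - sigma).
(auto_silu,) = torch.autograd.grad(torch.nn.functional.silu(x).sum(), x)
exp = torch.exp(-x)
sigma = 1 / (1 + exp)
manual_silu = sigma * sigma + (1 + x) / (2 + exp + 1 / exp)
assert torch.allclose(auto_silu, manual_silu)

# Tanh-approximated GELU: differentiate the same approximation with autograd and
# compare against the closed form used in the kernel (0.1070322243 = 3 * 0.79788456 * 0.044715).
tanh = torch.tanh(0.79788456 * x * (1 + 0.044715 * x * x))
(auto_gelu,) = torch.autograd.grad((x * 0.5 * (1.0 + tanh)).sum(), x)
manual_gelu = 0.5 * x * ((1 - tanh * tanh) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (1 + tanh)
assert torch.allclose(auto_gelu, manual_gelu)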