|
| 1 | +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. |
| 2 | +# Copyright 2022 The HuggingFace Team. All rights reserved. |
| 3 | +# |
| 4 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 5 | +# you may not use this file except in compliance with the License. |
| 6 | +# You may obtain a copy of the License at |
| 7 | +# |
| 8 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 9 | +# |
| 10 | +# Unless required by applicable law or agreed to in writing, software |
| 11 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 12 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 13 | +# See the License for the specific language governing permissions and |
| 14 | +# limitations under the License. |
| 15 | + |
| 16 | +import math |
| 17 | +from collections import OrderedDict |
| 18 | + |
| 19 | +import paddle |
| 20 | +import paddle.nn.functional as F |
| 21 | +from paddle import Tensor, nn |
| 22 | + |
| 23 | + |
| 24 | +class NewGELUActivation(nn.Layer): |
| 25 | + """ |
| 26 | + Implementation of the GELU activation function currently in Google BERT repo (identical to OpenAI GPT). Also see |
| 27 | + the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 |
| 28 | + """ |
| 29 | + |
| 30 | + def forward(self, input: Tensor) -> Tensor: |
| 31 | + return ( |
| 32 | + 0.5 * input * (1.0 + paddle.tanh(math.sqrt(2.0 / math.pi) * (input + 0.044715 * paddle.pow(input, 3.0)))) |
| 33 | + ) |
| 34 | + |
| 35 | + |
| 36 | +class GELUActivation(nn.Layer): |
| 37 | + """ |
| 38 | + Original Implementation of the GELU activation function in Google BERT repo when initially created. For |
| 39 | + information: OpenAI GPT's GELU is slightly different (and gives slightly different results): 0.5 * x * (1 + |
| 40 | + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))) This is now written in C in nn.functional |
| 41 | + Also see the Gaussian Error Linear Units paper: https://arxiv.org/abs/1606.08415 |
| 42 | + """ |
| 43 | + |
| 44 | + def __init__(self, use_gelu_python: bool = False): |
| 45 | + super().__init__() |
| 46 | + if use_gelu_python: |
| 47 | + self.act = self._gelu_python |
| 48 | + else: |
| 49 | + self.act = nn.functional.gelu |
| 50 | + |
| 51 | + def _gelu_python(self, input: Tensor) -> Tensor: |
| 52 | + return input * 0.5 * (1.0 + paddle.erf(input / math.sqrt(2.0))) |
| 53 | + |
| 54 | + def forward(self, input: Tensor) -> Tensor: |
| 55 | + return self.act(input) |
| 56 | + |
| 57 | + |
| 58 | +class FastGELUActivation(nn.Layer): |
| 59 | + """ |
| 60 | + Applies GELU approximation that is slower than QuickGELU but more accurate. See: https://github.com/hendrycks/GELUs |
| 61 | + """ |
| 62 | + |
| 63 | + def forward(self, input: Tensor) -> Tensor: |
| 64 | + return 0.5 * input * (1.0 + paddle.tanh(input * 0.7978845608 * (1.0 + 0.044715 * input * input))) |
| 65 | + |
| 66 | + |
| 67 | +class QuickGELUActivation(nn.Layer): |
| 68 | + """ |
| 69 | + Applies GELU approximation that is fast but somewhat inaccurate. See: https://github.com/hendrycks/GELUs |
| 70 | + """ |
| 71 | + |
| 72 | + def forward(self, input: Tensor) -> Tensor: |
| 73 | + return input * F.sigmoid(1.702 * input) |
| 74 | + |
| 75 | + |
| 76 | +class ClippedGELUActivation(nn.Layer): |
| 77 | + """ |
| 78 | + Clip the range of possible GeLU outputs between [min, max]. This is especially useful for quantization purpose, as |
| 79 | + it allows mapping negatives values in the GeLU spectrum. For more information on this trick, please refer to |
| 80 | + https://arxiv.org/abs/2004.09602. |
| 81 | +
|
| 82 | + Gaussian Error Linear Unit. Original Implementation of the gelu activation function in Google Bert repo when |
| 83 | + initially created. |
| 84 | +
|
| 85 | + For information: OpenAI GPT's gelu is slightly different (and gives slightly different results): 0.5 * x * (1 + |
| 86 | + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3)))). See https://arxiv.org/abs/1606.08415 |
| 87 | + """ |
| 88 | + |
| 89 | + def __init__(self, min: float, max: float): |
| 90 | + if min > max: |
| 91 | + raise ValueError(f"min should be < max (got min: {min}, max: {max})") |
| 92 | + |
| 93 | + super().__init__() |
| 94 | + self.min = min |
| 95 | + self.max = max |
| 96 | + |
| 97 | + def forward(self, x: Tensor) -> Tensor: |
| 98 | + return paddle.clip(gelu(x), self.min, self.max) |
| 99 | + |
| 100 | + |
| 101 | +class SiLUActivation(nn.Layer): |
| 102 | + """ |
| 103 | + See Gaussian Error Linear Units (Hendrycks et al., https://arxiv.org/abs/1606.08415) where the SiLU (Sigmoid Linear |
| 104 | + Unit) was originally introduced and coined, and see Sigmoid-Weighted Linear Units for Neural Network Function |
| 105 | + Approximation in Reinforcement Learning (Elfwing et al., https://arxiv.org/abs/1702.03118) and Swish: a Self-Gated |
| 106 | + Activation Function (Ramachandran et al., https://arxiv.org/abs/1710.05941v1) where the SiLU was experimented with |
| 107 | + later. |
| 108 | + """ |
| 109 | + |
| 110 | + def forward(self, input: Tensor) -> Tensor: |
| 111 | + return F.silu(input) |
| 112 | + |
| 113 | + |
| 114 | +class MishActivation(nn.Layer): |
| 115 | + """ |
| 116 | + See Mish: A Self-Regularized Non-Monotonic Activation Function (Misra., https://arxiv.org/abs/1908.08681). Also |
| 117 | + visit the official repository for the paper: https://github.com/digantamisra98/Mish |
| 118 | + """ |
| 119 | + |
| 120 | + def forward(self, input: Tensor) -> Tensor: |
| 121 | + return F.mish(input) |
| 122 | + |
| 123 | + |
| 124 | +class LinearActivation(nn.Layer): |
| 125 | + """ |
| 126 | + Applies the linear activation function, i.e. forwarding input directly to output. |
| 127 | + """ |
| 128 | + |
| 129 | + def forward(self, input: Tensor) -> Tensor: |
| 130 | + return input |
| 131 | + |
| 132 | + |
| 133 | +class ClassInstantier(OrderedDict): |
| 134 | + def __getitem__(self, key): |
| 135 | + content = super().__getitem__(key) |
| 136 | + cls, kwargs = content if isinstance(content, tuple) else (content, {}) |
| 137 | + return cls(**kwargs) |
| 138 | + |
| 139 | + |
| 140 | +ACT2CLS = { |
| 141 | + "gelu": GELUActivation, |
| 142 | + "gelu_10": (ClippedGELUActivation, {"min": -10, "max": 10}), |
| 143 | + "gelu_fast": FastGELUActivation, |
| 144 | + "gelu_new": NewGELUActivation, |
| 145 | + "gelu_python": (GELUActivation, {"use_gelu_python": True}), |
| 146 | + "linear": LinearActivation, |
| 147 | + "mish": MishActivation, |
| 148 | + "quick_gelu": QuickGELUActivation, |
| 149 | + "relu": nn.ReLU, |
| 150 | + "relu6": nn.ReLU6, |
| 151 | + "sigmoid": nn.Sigmoid, |
| 152 | + "silu": SiLUActivation, |
| 153 | + "swish": SiLUActivation, |
| 154 | + "tanh": nn.Tanh, |
| 155 | +} |
| 156 | +ACT2FN = ClassInstantier(ACT2CLS) |
| 157 | + |
| 158 | + |
| 159 | +def get_activation(activation_string): |
| 160 | + if activation_string in ACT2FN: |
| 161 | + return ACT2FN[activation_string] |
| 162 | + else: |
| 163 | + raise KeyError(f"function {activation_string} not found in ACT2FN mapping {list(ACT2FN.keys())}") |
| 164 | + |
| 165 | + |
| 166 | +# For backwards compatibility with: from activations import gelu_python |
| 167 | +gelu_python = get_activation("gelu_python") |
| 168 | +gelu_new = get_activation("gelu_new") |
| 169 | +gelu = get_activation("gelu") |
| 170 | +gelu_fast = get_activation("gelu_fast") |
| 171 | +quick_gelu = get_activation("quick_gelu") |
| 172 | +silu = get_activation("silu") |
| 173 | +mish = get_activation("mish") |
| 174 | +linear_act = get_activation("linear") |
0 commit comments