|
@@ -17,10 +17,12 @@
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available, is_torch_version
+from ..utils import deprecate, get_logger, is_kernels_available, is_torch_npu_available, is_torch_version
+from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 
 
+logger = get_logger(__name__)
+
 if is_torch_npu_available():
     import torch_npu
 
@@ -31,6 +33,7 @@
     "gelu": nn.GELU,
     "relu": nn.ReLU,
 }
+KERNELS_REPO_ID = "kernels-community/activation"
 
 
 def get_activation(act_fn: str) -> nn.Module:
@@ -90,6 +93,38 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class CUDAOptimizedGELU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+        if not torch.cuda.is_available():
+            raise NotImplementedError(f"{self.__class__.__name__} is implemented only for CUDA devices.")
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            raise RuntimeError(
+                f"{self.__class__.__name__} isn't usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+            )
+        if not is_kernels_available():
+            raise NotImplementedError(
+                f"{self.__class__.__name__} requires the `kernels` library to be installed. Install it with `pip install kernels`."
+            )
+
+        from kernels import get_kernel
+
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        activations = get_kernel(KERNELS_REPO_ID)
+        if approximate == "tanh":
+            self.act = activations.gelu_tanh_and_mul
+        elif approximate == "none":
+            self.act = activations.gelu_and_mul
+        else:
+            raise NotImplementedError(f"`approximate` must be 'none' or 'tanh', got {approximate!r}.")
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        out = torch.empty_like(hidden_states)
+        output = self.act(out, hidden_states)
+        return output
+
+
 class GEGLU(nn.Module):
     r"""
     A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
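For reference, here is a minimal sketch (not part of this commit) of calling the `kernels-community/activation` kernel that `CUDAOptimizedGELU` wires in, directly via `get_kernel`. It assumes the kernel follows the vLLM-style `gelu_and_mul(out, x)` convention, where `x` has last dimension `2 * d`, the GELU of the first half is gated by the second half, and the result is written into `out` with last dimension `d`; the shapes and dtype below are illustrative only.

```python
# Sketch only: direct use of the Hub activation kernel, assuming the vLLM-style
# gelu_and_mul convention (input [..., 2 * d] -> output [..., d]).
import torch
from kernels import get_kernel

activation = get_kernel("kernels-community/activation")  # same repo as KERNELS_REPO_ID

d = 640
x = torch.randn(2, 77, 2 * d, device="cuda", dtype=torch.float16)  # [..., 2 * d]
out = torch.empty(2, 77, d, device="cuda", dtype=torch.float16)  # [..., d]

# Computes GELU(x[..., :d]) * x[..., d:], writing the result in place into `out`.
activation.gelu_tanh_and_mul(out, x)
print(out.shape)  # torch.Size([2, 77, 640])
```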
|