
Commit 58743c3

kernelize gelu.
1 parent 50c0b78 · commit 58743c3

3 files changed (+37 −5 lines)

src/diffusers/models/activations.py

Lines changed: 37 additions & 2 deletions
@@ -17,10 +17,12 @@
 import torch.nn.functional as F
 from torch import nn
 
-from ..utils import deprecate
-from ..utils.import_utils import is_torch_npu_available, is_torch_version
+from ..utils import deprecate, get_logger, is_kernels_available, is_torch_npu_available, is_torch_version
+from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 
 
+logger = get_logger(__name__)
+
 if is_torch_npu_available():
     import torch_npu
 
@@ -31,6 +33,7 @@
     "gelu": nn.GELU,
     "relu": nn.ReLU,
 }
+KERNELS_REPO_ID = "kernels-community/activation"
 
 
 def get_activation(act_fn: str) -> nn.Module:
@@ -90,6 +93,38 @@ def forward(self, hidden_states):
         return hidden_states
 
 
+class CUDAOptimizedGELU(nn.Module):
+    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
+        if not torch.cuda.is_available():
+            raise NotImplementedError(f"{self.__class__.__name__} is implemented only for CUDA devices.")
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            raise RuntimeError(
+                f"{self.__class__.__name__} isn't usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
+            )
+        if not is_kernels_available():
+            raise NotImplementedError(
+                f"{self.__class__.__name__} requires the `kernels` library to be installed. Install it with `pip install kernels`."
+            )
+
+        from kernels import get_kernel
+
+        super().__init__()
+        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
+        activations = get_kernel(KERNELS_REPO_ID)
+        if approximate == "tanh":
+            self.act = activations.gelu_tanh_and_mul
+        elif approximate == "none":
+            self.act = activations.gelu_and_mul
+        else:
+            raise NotImplementedError
+
+    def forward(self, hidden_states):
+        hidden_states = self.proj(hidden_states)
+        out = torch.empty_like(hidden_states)
+        output = self.act(out, hidden_states)
+        return output
+
+
 class GEGLU(nn.Module):
     r"""
     A [variant](https://huggingface.co/papers/2002.05202) of the gated linear unit activation function.
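
For reference, a minimal usage sketch of the new `CUDAOptimizedGELU` layer (not part of the commit). It assumes a CUDA device, `pip install kernels`, and that `DIFFUSERS_ENABLE_HUB_KERNELS` is exported before Python starts; tensor shapes are illustrative.

```python
# Hedged usage sketch for the new CUDAOptimizedGELU; not part of the commit.
# Assumes a CUDA GPU, `pip install kernels`, and `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`
# done before launching Python so the diffusers constant picks it up.
import torch
from diffusers.models.activations import CUDAOptimizedGELU

# Linear projection 512 -> 2048, followed by the fused GELU kernel fetched from the
# kernels-community/activation Hub repo ("tanh" selects gelu_tanh_and_mul).
act = CUDAOptimizedGELU(dim_in=512, dim_out=2048, approximate="tanh").to(device="cuda", dtype=torch.float16)

x = torch.randn(2, 77, 512, device="cuda", dtype=torch.float16)  # illustrative shape
y = act(x)  # forward(): self.proj(x), then the fused activation kernel
```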

src/diffusers/models/attention.py

Lines changed: 0 additions & 2 deletions
@@ -20,7 +20,6 @@
 
 from ..utils import deprecate, logging
 from ..utils.import_utils import is_torch_npu_available, is_torch_xla_available, is_xformers_available
-from ..utils.kernels_utils import use_kernel_forward_from_hub
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU, FP32SiLU, LinearActivation, SwiGLU
 from .attention_processor import Attention, AttentionProcessor, JointAttnProcessor2_0
@@ -1670,7 +1669,6 @@ def forward(
         return hidden_states
 
 
-@use_kernel_forward_from_hub("MLP")
 class FeedForward(nn.Module):
     r"""
     A feed-forward layer.
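
For context on the removal above: the `use_kernel_forward_from_hub("MLP")` decorator tied `FeedForward` to the Hub layer registered under the `"MLP"` key in `kernels_utils.py` (also dropped in this commit, see below). A rough sketch of that decorator-based wiring, with a deliberately simplified `FeedForward` body; it assumes these helpers are exposed at the `kernels` top level and does not show the kernelize step that actually swaps the forward.

```python
# Rough sketch of the decorator-based path this commit removes; not part of the commit.
# Assumes LayerRepository, register_kernel_mapping and use_kernel_forward_from_hub are
# exposed by the `kernels` package; the kernelize step that performs the swap is not shown.
import torch.nn as nn
from kernels import LayerRepository, register_kernel_mapping, use_kernel_forward_from_hub

# Map the "MLP" layer name to a Hub kernel repository for CUDA devices
# (same structure as the _KERNEL_MAPPING entry removed below).
register_kernel_mapping(
    {"MLP": {"cuda": LayerRepository(repo_id="medmekk/triton-llama-mlp", layer_name="TritonLlamaMLP")}}
)

@use_kernel_forward_from_hub("MLP")  # tag the class so its forward can be replaced by the Hub layer
class FeedForward(nn.Module):
    # Simplified stand-in; the real diffusers FeedForward (and the Hub layer's
    # expectations about its attributes) are more involved.
    def __init__(self, dim: int, inner_dim: int):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU(), nn.Linear(inner_dim, dim))

    def forward(self, hidden_states):
        return self.net(hidden_states)
```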

src/diffusers/utils/kernels_utils.py

Lines changed: 0 additions & 1 deletion
@@ -38,7 +38,6 @@ def _get_fa3_from_hub():
     "RMSNorm": {
         "cuda": LayerRepository(repo_id="kernels-community/liger_kernels", layer_name="LigerRMSNorm"),
     },
-    "MLP": {"cuda": LayerRepository(repo_id="medmekk/triton-llama-mlp", layer_name="TritonLlamaMLP")},
 }
 
 register_kernel_mapping(_KERNEL_MAPPING)
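
With the `"MLP"` entry removed, only the RMSNorm mapping stays registered here; the new GELU path instead fetches its kernel directly with `get_kernel(KERNELS_REPO_ID)`. A hedged sketch of that direct fetch, assuming the repo follows the usual fused `*_and_mul` convention (the input's last dimension is split in half, GELU is applied to the first half and multiplied element-wise by the second); shapes are illustrative.

```python
# Hedged sketch of the direct get_kernel path used by CUDAOptimizedGELU; not part of the commit.
# Assumption: kernels-community/activation follows the fused *_and_mul convention, i.e.
# out = GELU(x[..., :d]) * x[..., d:], where x has last dim 2*d and out has last dim d.
import torch
from kernels import get_kernel

activation = get_kernel("kernels-community/activation")  # downloads/loads the compiled kernels

d = 1024
x = torch.randn(8, 2 * d, device="cuda", dtype=torch.float16)
out = torch.empty(8, d, device="cuda", dtype=torch.float16)
activation.gelu_and_mul(out, x)  # fused GELU + elementwise multiply, written into `out`
```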
