
Commit 33a8a3b ("up")

1 parent: 58743c3

File tree

2 files changed: +39 / -32 lines


src/diffusers/models/activations.py

Lines changed: 14 additions & 27 deletions

@@ -17,8 +17,7 @@
 import torch.nn.functional as F
 from torch import nn

-from ..utils import deprecate, get_logger, is_kernels_available, is_torch_npu_available, is_torch_version
-from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
+from ..utils import deprecate, get_logger, is_torch_npu_available, is_torch_version


 logger = get_logger(__name__)
@@ -93,36 +92,24 @@ def forward(self, hidden_states):
         return hidden_states


-class CUDAOptimizedGELU(nn.Module):
-    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", bias: bool = True):
-        if not torch.cuda.is_available():
-            raise NotImplementedError(f"{self.__class__.__name__} is implemented only for CUDA devices.")
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            raise RuntimeError(
-                f"{self.__class__.__name__} isn't usable because the `DIFFUSERS_ENABLE_HUB_KERNELS` env var isn't set. Please set it like `export DIFFUSERS_ENABLE_HUB_KERNELS=yes`."
-            )
-        if not is_kernels_available():
-            raise NotImplementedError(
-                f"{self.__class__.__name__} requires the `kernels` library to be installed. Install it with `pip install kernels`."
-            )
-
+# TODO: validation checks / consider making Python classes of activations like `transformers`
+# All of these are temporary for now.
+class CUDAOptimizedGELU(GELU):
+    def __init__(self, *args, **kwargs):
         from kernels import get_kernel

-        super().__init__()
-        self.proj = nn.Linear(dim_in, dim_out, bias=bias)
-        activations = get_kernel(KERNELS_REPO_ID)
-        if approximate == "tanh":
-            self.act = activations.gelu_tanh_and_mul
-        elif approximate == "none":
-            self.act = activations.gelu_and_mul
-        else:
-            raise NotImplementedError
+        activation = get_kernel("kernels-community/activation", revision="add_more_act")
+        approximate = kwargs.get("approximate", "none")
+        if approximate == "none":
+            self.act_fn = activation.gelu
+        elif approximate == "tanh":
+            self.act_fn = activation.gelu_tanh
+        super().__init__(*args, **kwargs)

     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
-        out = torch.empty_like(hidden_states)
-        output = self.act(out, hidden_states)
-        return output
+        hidden_states = self.act_fn(hidden_states)
+        return hidden_states


 class GEGLU(nn.Module):
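
A minimal usage sketch of the new Hub-kernel GELU path (not part of this commit): it assumes the `kernels` package is installed, a CUDA device is available, `DIFFUSERS_ENABLE_HUB_KERNELS` is set before `diffusers` is imported, and that `CUDAOptimizedGELU` keeps the constructor signature inherited from `GELU` (`dim_in`, `dim_out`, `approximate`, `bias`). Shapes and dtypes are illustrative only.

import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"  # opt in before importing diffusers

import torch

from diffusers.models.activations import CUDAOptimizedGELU

# The projection comes from the parent GELU block; the activation itself is the
# gelu_tanh function fetched from the kernels-community/activation Hub repo.
gelu = CUDAOptimizedGELU(dim_in=1024, dim_out=4096, approximate="tanh").to("cuda", torch.float16)
x = torch.randn(2, 16, 1024, device="cuda", dtype=torch.float16)
y = gelu(x)  # proj -> Hub gelu_tanh kernel
print(y.shape)  # torch.Size([2, 16, 4096])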

src/diffusers/models/normalization.py

Lines changed: 25 additions & 5 deletions

@@ -20,12 +20,20 @@
 import torch.nn as nn
 import torch.nn.functional as F

-from ..utils import is_torch_npu_available, is_torch_version
+from ..utils import is_kernels_available, is_torch_npu_available, is_torch_version
+from ..utils.constants import DIFFUSERS_ENABLE_HUB_KERNELS
 from ..utils.kernels_utils import use_kernel_forward_from_hub
 from .activations import get_activation
 from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings


+if is_kernels_available() and DIFFUSERS_ENABLE_HUB_KERNELS:
+    from kernels import get_kernel
+
+    activation = get_kernel("kernels-community/activation", revision="add_more_act")
+    silu_kernel = activation.silu
+
+
 class AdaLayerNorm(nn.Module):
     r"""
     Norm layer modified to incorporate timestep embeddings.
@@ -58,7 +66,10 @@ def __init__(
         else:
             self.emb = None

-        self.silu = nn.SiLU()
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = nn.SiLU()
+        else:
+            self.silu = silu_kernel
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

@@ -145,7 +156,10 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, nor
         else:
             self.emb = None

-        self.silu = nn.SiLU()
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = nn.SiLU()
+        else:
+            self.silu = silu_kernel
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -184,7 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
         super().__init__()

-        self.silu = nn.SiLU()
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = nn.SiLU()
+        else:
+            self.silu = silu_kernel
         self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -336,7 +353,10 @@ def __init__(
         norm_type="layer_norm",
     ):
         super().__init__()
-        self.silu = nn.SiLU()
+        if not DIFFUSERS_ENABLE_HUB_KERNELS:
+            self.silu = nn.SiLU()
+        else:
+            self.silu = silu_kernel
         self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
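
A matching sketch for the norm layers (again not part of the commit), under the same assumptions: `kernels` installed and `DIFFUSERS_ENABLE_HUB_KERNELS` set before import so the module-level `silu_kernel` fetch runs; with the flag unset the classes fall back to `nn.SiLU()` and behave as before. `AdaLayerNormZeroSingle` is used here purely as one example of a class that picks up `silu_kernel`.

import os

os.environ["DIFFUSERS_ENABLE_HUB_KERNELS"] = "yes"  # opt in before importing diffusers

import torch

from diffusers.models.normalization import AdaLayerNormZeroSingle

# self.silu is the Hub silu kernel when the flag is set, plain nn.SiLU() otherwise.
norm = AdaLayerNormZeroSingle(embedding_dim=512).to("cuda", torch.float16)
x = torch.randn(2, 64, 512, device="cuda", dtype=torch.float16)
emb = torch.randn(2, 512, device="cuda", dtype=torch.float16)
out, gate_msa = norm(x, emb)  # SiLU(emb) -> linear -> shift/scale/gate modulation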
