
Commit 04e9323

Commit message: up
1 parent: 9a09162

File tree

3 files changed: +11 -10 lines


src/diffusers/models/activations.py

Lines changed: 4 additions & 3 deletions
@@ -100,11 +100,12 @@ def __init__(self, *args, **kwargs):
 
         activation = get_kernel("kernels-community/activation", revision="add_more_act")
         approximate = kwargs.get("approximate", "none")
+
+        super().__init__(*args, **kwargs)
         if approximate == "none":
-            self.act_fn = activation.gelu
+            self.act_fn = activation.layers.Gelu()
         elif approximate == "tanh":
-            self.act_fn = activation.gelu_tanh
-        super().__init__(*args, **kwargs)
+            self.act_fn = activation.layers.GeluTanh()
 
     def forward(self, hidden_states):
         hidden_states = self.proj(hidden_states)
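
For context, a minimal self-contained sketch (not part of the commit) of the pattern this hunk settles on: call super().__init__() before assigning the kernel-backed activation, and instantiate the activation.layers classes instead of binding raw kernel functions. The HubGELU class, its constructor arguments, and the _HAS_HUB_KERNEL fallback flag below are illustrative assumptions, not diffusers code.

# Sketch only: mirrors the structure of the patched __init__ above.
import torch
import torch.nn as nn

try:
    from kernels import get_kernel
    # Same Hub repo and revision as in the diff.
    activation = get_kernel("kernels-community/activation", revision="add_more_act")
    _HAS_HUB_KERNEL = True
except Exception:
    _HAS_HUB_KERNEL = False


class HubGELU(nn.Module):  # illustrative class, not the diffusers GELU
    def __init__(self, dim_in: int, dim_out: int, approximate: str = "none"):
        super().__init__()  # run before any attribute assignment, as in the patch
        self.proj = nn.Linear(dim_in, dim_out)
        if _HAS_HUB_KERNEL:
            # Instantiate the layer classes rather than binding raw kernel functions.
            self.act_fn = (
                activation.layers.Gelu() if approximate == "none"
                else activation.layers.GeluTanh()
            )
        else:
            self.act_fn = nn.GELU(approximate=approximate)  # eager fallback

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.proj(hidden_states)
        return self.act_fn(hidden_states)


# Usage
layer = HubGELU(16, 32, approximate="tanh")
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 32])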

src/diffusers/models/normalization.py

Lines changed: 5 additions & 5 deletions
@@ -31,7 +31,7 @@
 from kernels import get_kernel
 
 activation = get_kernel("kernels-community/activation", revision="add_more_act")
-silu_kernel = activation.silu
+silu_kernel = activation.layers.Silu
 
 
 class AdaLayerNorm(nn.Module):
@@ -69,7 +69,7 @@ def __init__(
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = nn.SiLU()
         else:
-            self.silu = silu_kernel
+            self.silu = silu_kernel()
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
 
@@ -159,7 +159,7 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, nor
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = nn.SiLU()
         else:
-            self.silu = silu_kernel
+            self.silu = silu_kernel()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -201,7 +201,7 @@ def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = nn.SiLU()
         else:
-            self.silu = silu_kernel
+            self.silu = silu_kernel()
         self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
@@ -356,7 +356,7 @@ def __init__(
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = nn.SiLU()
         else:
-            self.silu = silu_kernel
+            self.silu = silu_kernel()
         self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
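
All five hunks above make the same one-line change, so a single sketch covers them: keep nn.SiLU() as the eager fallback, and when hub kernels are enabled, instantiate the Silu layer class rather than assigning the bare class. The AdaNormSketch class and reading DIFFUSERS_ENABLE_HUB_KERNELS from the environment here are illustrative assumptions, not the diffusers implementation.

# Sketch only: the gating pattern repeated in each hunk of normalization.py.
import os

import torch
import torch.nn as nn

# Mirrors the flag used in the diff; read from the environment for this sketch.
DIFFUSERS_ENABLE_HUB_KERNELS = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "0") == "1"

if DIFFUSERS_ENABLE_HUB_KERNELS:
    from kernels import get_kernel

    activation = get_kernel("kernels-community/activation", revision="add_more_act")
    silu_kernel = activation.layers.Silu  # a class; instantiated per module below
else:
    silu_kernel = None


class AdaNormSketch(nn.Module):  # illustrative, not the diffusers AdaLayerNorm
    def __init__(self, embedding_dim: int, output_dim: int):
        super().__init__()
        if not DIFFUSERS_ENABLE_HUB_KERNELS:
            self.silu = nn.SiLU()
        else:
            self.silu = silu_kernel()  # instantiate, not assign the bare class
        self.linear = nn.Linear(embedding_dim, output_dim)

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        return self.linear(self.silu(emb))


print(AdaNormSketch(8, 16)(torch.randn(2, 8)).shape)  # torch.Size([2, 16])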

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 2 additions & 2 deletions
@@ -45,7 +45,7 @@
 from kernels import get_kernel
 
 activation = get_kernel("kernels-community/activation", revision="add_more_act")
-gelu_tanh_kernel = activation.gelu_tanh
+gelu_tanh_kernel = activation.layers.GeluTanh
 
 
 def _get_projections(attn: "FluxAttention", hidden_states, encoder_hidden_states=None):
@@ -360,7 +360,7 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,
         if not DIFFUSERS_ENABLE_HUB_KERNELS:
             self.act_mlp = nn.GELU(approximate="tanh")
         else:
-            self.act_mlp = gelu_tanh_kernel
+            self.act_mlp = gelu_tanh_kernel()
 
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
 
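
A rough sanity-check sketch (not part of the commit) of why nn.GELU(approximate="tanh") stays as the eager fallback for the instantiated GeluTanh layer. The hard-coded flag and the reference formula below are assumptions for illustration only.

# Sketch only: compares the eager fallback against the tanh-approximate GELU formula.
import torch
import torch.nn as nn

DIFFUSERS_ENABLE_HUB_KERNELS = False  # flip to True to exercise the hub-kernel path

if DIFFUSERS_ENABLE_HUB_KERNELS:
    from kernels import get_kernel

    # Same Hub repo/revision as the diff; GeluTanh is the layer class it switches to.
    activation = get_kernel("kernels-community/activation", revision="add_more_act")
    gelu_tanh_kernel = activation.layers.GeluTanh
    act_mlp = gelu_tanh_kernel()  # instantiated, as in the patched block above
else:
    act_mlp = nn.GELU(approximate="tanh")  # eager fallback kept by the diff

x = torch.randn(4, 8)
# tanh approximation of GELU: 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
ref = 0.5 * x * (1.0 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))
print(torch.allclose(act_mlp(x), ref, atol=1e-5))  # True on the eager path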
