Commit 92199ff

up
committed
1 parent 04e9323 commit 92199ff

2 files changed, +33 -20 lines changed

src/diffusers/models/normalization.py

Lines changed: 12 additions & 12 deletions
@@ -66,10 +66,10 @@ def __init__(
         else:
             self.emb = None

-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, output_dim)
         self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)

@@ -156,10 +156,10 @@ def __init__(self, embedding_dim: int, num_embeddings: Optional[int] = None, nor
         else:
             self.emb = None

-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

@@ -198,10 +198,10 @@ class AdaLayerNormZeroSingle(nn.Module):
     def __init__(self, embedding_dim: int, norm_type="layer_norm", bias=True):
         super().__init__()

-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(embedding_dim, 3 * embedding_dim, bias=bias)
         if norm_type == "layer_norm":
             self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)

@@ -353,10 +353,10 @@ def __init__(
         norm_type="layer_norm",
     ):
         super().__init__()
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.silu = nn.SiLU()
-        else:
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
             self.silu = silu_kernel()
+        else:
+            self.silu = nn.SiLU()
         self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
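
For context, a minimal runnable sketch of the gating pattern these hunks reorder, assuming DIFFUSERS_ENABLE_HUB_KERNELS is a boolean derived from an environment variable and silu_kernel() returns a drop-in SiLU module backed by a Hub kernel; the helper below is a hypothetical stand-in, not the actual diffusers code:

import os

import torch
import torch.nn as nn

# Assumed: the flag is read from the environment (hypothetical parsing).
DIFFUSERS_ENABLE_HUB_KERNELS = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "0") == "1"


def silu_kernel() -> nn.Module:
    # Hypothetical stand-in: a real helper would load a fused SiLU kernel from the Hub;
    # here it falls back to the eager op so the sketch runs anywhere.
    return nn.SiLU()


class AdaNormSketch(nn.Module):
    def __init__(self, embedding_dim: int, output_dim: int):
        super().__init__()
        # Same branch order as the new code: prefer the Hub kernel when the flag is set.
        if DIFFUSERS_ENABLE_HUB_KERNELS:
            self.silu = silu_kernel()
        else:
            self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, output_dim)

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        return self.linear(self.silu(emb))

Either branch order is behaviorally equivalent; the rewrite only puts the kernel path first when the flag is enabled.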

src/diffusers/models/transformers/transformer_flux.py

Lines changed: 21 additions & 8 deletions
@@ -307,8 +307,14 @@ def __init__(
         self.added_kv_proj_dim = added_kv_proj_dim
         self.added_proj_bias = added_proj_bias

-        self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
-        self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        if DIFFUSERS_ENABLE_HUB_KERNELS:
+            from ..normalization import RMSNorm
+
+            self.norm_q = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+        else:
+            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
         self.to_q = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_k = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)
         self.to_v = torch.nn.Linear(query_dim, self.inner_dim, bias=bias)

@@ -319,8 +325,14 @@ def __init__(
         self.to_out.append(torch.nn.Dropout(dropout))

         if added_kv_proj_dim is not None:
-            self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
-            self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
+            if DIFFUSERS_ENABLE_HUB_KERNELS:
+                from ..normalization import RMSNorm
+
+                self.norm_added_q = RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            else:
+                self.norm_added_q = torch.nn.RMSNorm(dim_head, eps=eps)
+                self.norm_added_k = torch.nn.RMSNorm(dim_head, eps=eps)
             self.add_q_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_k_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
             self.add_v_proj = torch.nn.Linear(added_kv_proj_dim, self.inner_dim, bias=added_proj_bias)
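
A hedged sketch of the query/key-norm selection above, assuming the intent is to route through diffusers' own RMSNorm (which can be backed by a Hub kernel) when the flag is set, and to fall back to torch.nn.RMSNorm (available in PyTorch 2.4+) otherwise. TinyRMSNorm is a hypothetical stand-in for the library class so the example runs standalone:

import os

import torch
import torch.nn as nn

DIFFUSERS_ENABLE_HUB_KERNELS = os.getenv("DIFFUSERS_ENABLE_HUB_KERNELS", "0") == "1"


class TinyRMSNorm(nn.Module):
    # Hypothetical stand-in for diffusers.models.normalization.RMSNorm.
    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine: bool = False):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim)) if elementwise_affine else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * self.weight if self.weight is not None else x


class QKNormSketch(nn.Module):
    def __init__(self, dim_head: int, eps: float = 1e-6, elementwise_affine: bool = False):
        super().__init__()
        # Mirror the diff: library RMSNorm (stand-in here) on the kernel path,
        # torch.nn.RMSNorm otherwise.
        if DIFFUSERS_ENABLE_HUB_KERNELS:
            self.norm_q = TinyRMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = TinyRMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
        else:
            self.norm_q = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
            self.norm_k = torch.nn.RMSNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)

    def forward(self, q: torch.Tensor, k: torch.Tensor):
        return self.norm_q(q), self.norm_k(k)

The diff imports RMSNorm locally inside the enabled branch, so the normalization module is only pulled in when the kernel path is actually taken.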
@@ -357,10 +369,11 @@ def __init__(self, dim: int, num_attention_heads: int, attention_head_dim: int,

         self.norm = AdaLayerNormZeroSingle(dim)
         self.proj_mlp = nn.Linear(dim, self.mlp_hidden_dim)
-        if not DIFFUSERS_ENABLE_HUB_KERNELS:
-            self.act_mlp = nn.GELU(approximate="tanh")
-        else:
-            self.act_mlp = gelu_tanh_kernel()
+        self.act_mlp = nn.GELU(approximate="tanh")
+        # if not DIFFUSERS_ENABLE_HUB_KERNELS:
+        #     self.act_mlp = nn.GELU(approximate="tanh")
+        # else:
+        #     self.act_mlp = gelu_tanh_kernel()

         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)
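
The last hunk pins act_mlp to the eager tanh-approximate GELU and leaves the Hub-kernel branch commented out. As a quick sanity check, under the assumption that any kernel-backed replacement should match PyTorch's documented tanh approximation, the eager op can be compared against the closed-form formula:

import math

import torch
import torch.nn as nn

x = torch.randn(4, 8)
eager = nn.GELU(approximate="tanh")(x)
# Tanh approximation as documented for nn.GELU(approximate="tanh").
reference = 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x.pow(3))))
assert torch.allclose(eager, reference, atol=1e-5)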
