Skip to content

Commit 3efe355

Browse files
yiyixuxuyiyixuxu
authored and committed
add self.use_ada_layer_norm_* params back to BasicTransformerBlock (#6841)
fix sd reference community pipeline Co-authored-by: yiyixuxu <yixu310@gmail.com>
1 parent 08e6558 commit 3efe355

File tree

3 files changed

+14
-4
lines changed

3 files changed

+14
-4
lines changed

examples/community/stable_diffusion_reference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,7 @@ def hack_CrossAttnDownBlock2D_forward(
538538

539539
return hidden_states, output_states
540540

541-
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
541+
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
542542
eps = 1e-6
543543

544544
output_states = ()
@@ -634,7 +634,9 @@ def hacked_CrossAttnUpBlock2D_forward(
634634

635635
return hidden_states
636636

637-
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
637+
def hacked_UpBlock2D_forward(
638+
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
639+
):
638640
eps = 1e-6
639641
for i, resnet in enumerate(self.resnets):
640642
# pop res hidden states

examples/community/stable_diffusion_xl_reference.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -507,7 +507,7 @@ def hack_CrossAttnDownBlock2D_forward(
507507

508508
return hidden_states, output_states
509509

510-
def hacked_DownBlock2D_forward(self, hidden_states, temb=None):
510+
def hacked_DownBlock2D_forward(self, hidden_states, temb=None, **kwargs):
511511
eps = 1e-6
512512

513513
output_states = ()
@@ -603,7 +603,9 @@ def hacked_CrossAttnUpBlock2D_forward(
603603

604604
return hidden_states
605605

606-
def hacked_UpBlock2D_forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
606+
def hacked_UpBlock2D_forward(
607+
self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, **kwargs
608+
):
607609
eps = 1e-6
608610
for i, resnet in enumerate(self.resnets):
609611
# pop res hidden states

src/diffusers/models/attention.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,12 @@ def __init__(
158158
super().__init__()
159159
self.only_cross_attention = only_cross_attention
160160

161+
self.use_ada_layer_norm_zero = (num_embeds_ada_norm is not None) and norm_type == "ada_norm_zero"
162+
self.use_ada_layer_norm = (num_embeds_ada_norm is not None) and norm_type == "ada_norm"
163+
self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
164+
self.use_layer_norm = norm_type == "layer_norm"
165+
self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
166+
161167
if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
162168
raise ValueError(
163169
f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"

0 commit comments

Comments
 (0)