From d07a24c5d4d0d5bac93a75a10b1c15ce8252dd90 Mon Sep 17 00:00:00 2001
From: Will
Date: Thu, 4 Apr 2024 17:46:09 -0400
Subject: [PATCH 1/3] Fix Transformer2DModel ada_norm

---
 src/diffusers/models/normalization.py        |  4 ++--
 .../models/transformers/transformer_2d.py    | 14 ++++++++++----
 2 files changed, 12 insertions(+), 6 deletions(-)

diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index 036a66890e67..c9f75a3dcfb2 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -43,8 +43,8 @@ def __init__(self, embedding_dim: int, num_embeddings: int):
 
     def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
         emb = self.linear(self.silu(self.emb(timestep)))
-        scale, shift = torch.chunk(emb, 2)
-        x = self.norm(x) * (1 + scale) + shift
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale[:, None, :]) + shift[:, None, :]
         return x
 
 
diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index a9782256757b..fbb57e9cb51c 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -492,10 +492,8 @@ def custom_forward(*inputs):
             output = F.log_softmax(logits.double(), dim=1).float()
 
         if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
-                )
+            if self.config.norm_type == "ada_norm":
+                conditioning = self.transformer_blocks[0].norm1.emb(timestep)
                 shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
                 hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
                 hidden_states = self.proj_out_2(hidden_states)
@@ -506,6 +504,14 @@ def custom_forward(*inputs):
                 hidden_states = hidden_states * (1 + scale) + shift
                 hidden_states = self.proj_out(hidden_states)
                 hidden_states = hidden_states.squeeze(1)
+            else:
+                conditioning = self.transformer_blocks[0].norm1.emb(
+                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+                )
+                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+                hidden_states = self.norm_out(hidden_states) * (
+                    1 + scale[:, None]) + shift[:, None]
+                hidden_states = self.proj_out_2(hidden_states)
 
         # unpatchify
         if self.adaln_single is None:

From 78c3d9ff82cbacd85f3ba69ff353c458cc231721 Mon Sep 17 00:00:00 2001
From: Will
Date: Thu, 4 Apr 2024 19:51:49 -0400
Subject: [PATCH 2/3] make style

---
 src/diffusers/models/transformers/transformer_2d.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/src/diffusers/models/transformers/transformer_2d.py b/src/diffusers/models/transformers/transformer_2d.py
index fbb57e9cb51c..b0ea809f4363 100644
--- a/src/diffusers/models/transformers/transformer_2d.py
+++ b/src/diffusers/models/transformers/transformer_2d.py
@@ -509,8 +509,7 @@ def custom_forward(*inputs):
                     timestep, class_labels, hidden_dtype=hidden_states.dtype
                 )
                 shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
-                hidden_states = self.norm_out(hidden_states) * (
-                    1 + scale[:, None]) + shift[:, None]
+                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
                 hidden_states = self.proj_out_2(hidden_states)
 
         # unpatchify

From 1c1d00e28c2f2e90d807f95a1594eefcf73cdeb1 Mon Sep 17 00:00:00 2001
From: Will
Date: Fri, 5 Apr 2024 09:19:36 -0400
Subject: [PATCH 3/3] fix spatial transformer test

---
 src/diffusers/models/normalization.py | 2 +-
 tests/models/test_layers_utils.py     | 4 ++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/diffusers/models/normalization.py b/src/diffusers/models/normalization.py
index c9f75a3dcfb2..6e76518fe24b 100644
--- a/src/diffusers/models/normalization.py
+++ b/src/diffusers/models/normalization.py
@@ -44,7 +44,7 @@ def __init__(self, embedding_dim: int, num_embeddings: int):
     def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
         emb = self.linear(self.silu(self.emb(timestep)))
         scale, shift = torch.chunk(emb, 2, dim=1)
-        x = self.norm(x) * (1 + scale[:, None, :]) + shift[:, None, :]
+        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
         return x
 
 
diff --git a/tests/models/test_layers_utils.py b/tests/models/test_layers_utils.py
index b5a5bec471a6..3f13634e9513 100644
--- a/tests/models/test_layers_utils.py
+++ b/tests/models/test_layers_utils.py
@@ -378,8 +378,8 @@ def test_spatial_transformer_timestep(self):
             num_embeds_ada_norm=num_embeds_ada_norm,
         ).to(torch_device)
         with torch.no_grad():
-            timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
-            timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
+            timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)[None]
+            timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)[None]
             attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
             attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample
 
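
Note on the shape fix: the sketch below restates the patched forward from src/diffusers/models/normalization.py outside the library, to show why the chunk has to run along dim=1 and why scale/shift need an extra broadcast axis once timestep is batched. This is a minimal sketch, not the library class itself: the class name, the __init__ layer choices, and the concrete sizes are assumptions reconstructed from the hunk context.

import torch
import torch.nn as nn


class AdaLayerNormSketch(nn.Module):
    # Provides the attributes the patched forward relies on (emb, silu, linear, norm);
    # the exact layer choices here are an assumption, not copied from diffusers.
    def __init__(self, embedding_dim: int, num_embeddings: int):
        super().__init__()
        self.emb = nn.Embedding(num_embeddings, embedding_dim)
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
        self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)

    def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
        # timestep: (batch,) -> emb: (batch, 2 * embedding_dim)
        emb = self.linear(self.silu(self.emb(timestep)))
        # chunk along the feature axis; the old default (dim=0) split the batch instead
        scale, shift = torch.chunk(emb, 2, dim=1)
        # (batch, embedding_dim) -> (batch, 1, embedding_dim) so it broadcasts over seq_len
        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
        return x


# Shape check under assumed sizes; batched timesteps mirror the updated test above.
norm = AdaLayerNormSketch(embedding_dim=32, num_embeddings=8)
x = torch.randn(2, 16, 32)                          # (batch, seq_len, embedding_dim)
timesteps = torch.tensor([1, 2], dtype=torch.long)  # one timestep index per batch item
assert norm(x, timesteps).shape == x.shape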