src/diffusers/models/normalization.py (4 changes: 2 additions & 2 deletions)

@@ -43,8 +43,8 @@ def __init__(self, embedding_dim: int, num_embeddings: int):

     def forward(self, x: torch.Tensor, timestep: torch.Tensor) -> torch.Tensor:
         emb = self.linear(self.silu(self.emb(timestep)))
-        scale, shift = torch.chunk(emb, 2)
-        x = self.norm(x) * (1 + scale) + shift
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale[:, None]) + shift[:, None]
         return x

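For context on why the chunk now uses dim=1 and the scale/shift gain a [:, None], here is a minimal, self-contained sketch with toy shapes (not the diffusers AdaLayerNorm class itself): with a batched timestep the embedding has shape (batch, 2 * dim), so it has to be split per sample and then broadcast over the sequence dimension.

import torch
from torch import nn

# Minimal sketch of the shape handling introduced above (toy sizes, not the
# actual AdaLayerNorm module). With a batched timestep, `emb` is
# (batch, 2 * dim); chunking along dim=1 gives per-sample scale/shift, and
# [:, None] broadcasts them over the sequence dimension of x.
batch, seq_len, dim = 2, 4, 8
emb = torch.randn(batch, 2 * dim)                  # output of the timestep MLP
x = torch.randn(batch, seq_len, dim)
norm = nn.LayerNorm(dim, elementwise_affine=False)

scale, shift = torch.chunk(emb, 2, dim=1)          # each (batch, dim)
out = norm(x) * (1 + scale[:, None]) + shift[:, None]
assert out.shape == (batch, seq_len, dim)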
src/diffusers/models/transformers/transformer_2d.py (13 changes: 9 additions & 4 deletions)

@@ -492,10 +492,8 @@ def custom_forward(*inputs):
             output = F.log_softmax(logits.double(), dim=1).float()

         if self.is_input_patches:
-            if self.config.norm_type != "ada_norm_single":
-                conditioning = self.transformer_blocks[0].norm1.emb(
-                    timestep, class_labels, hidden_dtype=hidden_states.dtype
-                )
+            if self.config.norm_type == "ada_norm":
+                conditioning = self.transformer_blocks[0].norm1.emb(timestep)
                 shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
                 hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
                 hidden_states = self.proj_out_2(hidden_states)

@@ -506,6 +504,13 @@ def custom_forward(*inputs):
                 hidden_states = hidden_states * (1 + scale) + shift
                 hidden_states = self.proj_out(hidden_states)
                 hidden_states = hidden_states.squeeze(1)
+            else:
+                conditioning = self.transformer_blocks[0].norm1.emb(
+                    timestep, class_labels, hidden_dtype=hidden_states.dtype
+                )
+                shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+                hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+                hidden_states = self.proj_out_2(hidden_states)
Comment on lines +507 to +513
Member:
Isn't this practically the same as what we're doing when norm_type == "ada_norm"?

Contributor Author:
It is similar, but ada_norm doesn't take class labels as an argument. I moved this into the else branch because the original condition was norm_type != ada_norm_single. From what I can tell, that block still only supports the norm used in the original DiT implementation. It might be worth refactoring it to allow other norm types.

Member:
We could then still condition on if class_labels is not None, or something like that, no?

Contributor Author:
If class_labels is None then you can't use AdaLayerNormZero. I'm not sure what the default norm should be when you want to condition on text or audio, but I picked AdaLayerNorm because it was similar to the zero variant without needing class labels.
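As a rough illustration of that fallback (a hypothetical helper, not code from this PR), the choice could look like the sketch below, assuming both classes keep the (embedding_dim, num_embeddings) constructor shown in normalization.py above.

from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero

# Hypothetical helper sketching the fallback described above (not part of the
# PR): use AdaLayerNormZero when class labels are available, otherwise the
# plain timestep-conditioned AdaLayerNorm.
def pick_norm(embedding_dim: int, num_embeddings: int, has_class_labels: bool):
    if has_class_labels:
        return AdaLayerNormZero(embedding_dim, num_embeddings)
    return AdaLayerNorm(embedding_dim, num_embeddings)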

Member:
How about ada_norm_single, i.e., the one used in PixArt-Alpha?

Contributor Author:
If you use that one without additional arguments, you get this error:

TypeError: PixArtAlphaCombinedTimestepSizeEmbeddings(
  (time_proj): Timesteps()
  (timestep_embedder): TimestepEmbedding(
    (linear_1): Linear(in_features=256, out_features=1408, bias=True)
    (act): SiLU()
    (linear_2): Linear(in_features=1408, out_features=1408, bias=True)
  )
) argument after ** must be a mapping, not NoneType

It requires the arguments resolution and aspect_ratio, but those could have a default of None because they aren't required when use_additional_conditions=False.
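A self-contained sketch of that suggestion follows. The module and its names (ToyCombinedEmbeddings, size_embedder) are illustrative only, not the real PixArtAlphaCombinedTimestepSizeEmbeddings: the point is that the extra inputs default to None so the embedder still works when use_additional_conditions=False.

import torch
from torch import nn

# Hypothetical toy embedder illustrating the suggestion above; not the real
# PixArtAlphaCombinedTimestepSizeEmbeddings. resolution/aspect_ratio default to
# None because they are only consumed when use_additional_conditions=True.
class ToyCombinedEmbeddings(nn.Module):
    def __init__(self, dim: int, use_additional_conditions: bool = False):
        super().__init__()
        self.use_additional_conditions = use_additional_conditions
        self.timestep_embedder = nn.Sequential(nn.Linear(1, dim), nn.SiLU(), nn.Linear(dim, dim))
        self.size_embedder = nn.Sequential(nn.Linear(2, dim), nn.SiLU(), nn.Linear(dim, dim))

    def forward(self, timestep, resolution=None, aspect_ratio=None):
        emb = self.timestep_embedder(timestep.float()[:, None])
        if self.use_additional_conditions:
            size = torch.stack([resolution, aspect_ratio], dim=-1).float()
            emb = emb + self.size_embedder(size)
        return emb

emb = ToyCombinedEmbeddings(dim=32)(torch.tensor([3]))  # works with only a timestep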

Member:
Do you think changing that block would be more appropriate?

@yiyixuxu WDYT?


             # unpatchify
             if self.adaln_single is None:
tests/models/test_layers_utils.py (4 changes: 2 additions & 2 deletions)

@@ -378,8 +378,8 @@ def test_spatial_transformer_timestep(self):
             num_embeds_ada_norm=num_embeds_ada_norm,
         ).to(torch_device)
         with torch.no_grad():
-            timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)
-            timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)
+            timestep_1 = torch.tensor(1, dtype=torch.long).to(torch_device)[None]
+            timestep_2 = torch.tensor(2, dtype=torch.long).to(torch_device)[None]
             attention_scores_1 = spatial_transformer_block(sample, timestep=timestep_1).sample
             attention_scores_2 = spatial_transformer_block(sample, timestep=timestep_2).sample

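For reference, the trailing [None] added in the test gives each scalar timestep a leading batch dimension, matching the batched handling AdaLayerNorm now expects. A quick shape check (standalone sketch, not part of the test file):

import torch

# The [None] in the updated test adds a leading batch dimension.
t = torch.tensor(1, dtype=torch.long)
assert t.shape == torch.Size([])         # 0-d scalar tensor
assert t[None].shape == torch.Size([1])  # batch of one timestep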