Commit 0b95ec0

Restore BasicTransformerBlock to original huggingface implementation
1 parent e7a3cfd commit 0b95ec0

File tree

5 files changed: +202, -120 lines

examples/research_projects/pytorch_xla/training/text_to_image/train_text_to_image_sdxl.py

Lines changed: 14 additions & 24 deletions
```diff
@@ -549,10 +549,6 @@ def encode_prompt(batch, text_encoders, tokenizers, proportion_empty_prompts, ca


     prompt_embeds = torch.concat(prompt_embeds_list, dim=-1).to(dtype=dtype)
-    print(prompt_embeds.shape)
-    p3d = (0,0, 0, 128-77)
-    prompt_embeds = F.pad(prompt_embeds, p3d, "constant", 0)
-    print(prompt_embeds.shape)
     pooled_prompt_embeds = pooled_prompt_embeds.view(bs_embed, -1).to(dtype=dtype)
     return {"prompt_embeds": prompt_embeds, "pooled_prompt_embeds": pooled_prompt_embeds}

```
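Note on the hunk above: the deleted lines were a debugging hack that zero-padded `prompt_embeds` along the token dimension from 77 to 128 and printed the shapes. A minimal standalone sketch of what that padding did (the batch size and embedding width below are illustrative, not taken from the script):

```python
import torch
import torch.nn.functional as F

# Illustrative shapes: (batch, tokens, embed_dim); SDXL text embeddings have 77 tokens.
prompt_embeds = torch.randn(2, 77, 2048)

# F.pad reads the pad tuple from the last dimension backwards:
# (last_left, last_right, second_to_last_left, second_to_last_right).
# So (0, 0, 0, 128 - 77) leaves the embedding dim alone and right-pads the token dim to 128.
p3d = (0, 0, 0, 128 - 77)
padded = F.pad(prompt_embeds, p3d, "constant", 0)
print(prompt_embeds.shape, padded.shape)  # torch.Size([2, 77, 2048]) torch.Size([2, 128, 2048])
```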

```diff
@@ -667,11 +663,7 @@ def main(args):
         revision=args.revision,
         use_fast=False
     )
-
-    # from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
-
-    # unet = apply_xla_patch_to_nn_linear(unet, xs.xla_patched_nn_linear_forward)
-    unet.enable_xla_flash_attention(partition_spec=("data", None, None, None))
+    unet.enable_xla_attention(partition_spec=("data", None, None, None))

     vae.requires_grad_(False)
     text_encoder.requires_grad_(False)
@@ -758,11 +750,14 @@ def preprocess_train(examples):
     from datasets.fingerprint import Hasher
     # import pdb; pdb.set_trace()
     old_batch_size = args.train_batch_size
-    args.train_batch_size=21
+    old_arg = args.output_dir
+    args.output_dir = '/tmp/trained-model/'
+    args.train_batch_size=22
     new_fingerprint = Hasher.hash(args)
     args.train_batch_size=64
     new_fingerprint_for_vae = Hasher.hash((args.pretrained_model_name_or_path, args))
     args.train_batch_size=old_batch_size
+    args.output_dir = old_arg
     train_dataset_with_embeddings = train_dataset.map(
         compute_embeddings_fn, batched=True, batch_size=50, new_fingerprint=new_fingerprint
     )
```
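For context on the hunk above: the script temporarily overrides fields of `args` (batch size, and now also `output_dir`) before hashing it, so that run-to-run differences in those fields do not change the cache fingerprint passed to `datasets.Dataset.map()`. A minimal sketch of that caching pattern, with hypothetical argument values, assuming the `datasets` library is installed:

```python
import argparse
from datasets import Dataset
from datasets.fingerprint import Hasher

# Hypothetical stand-ins for the script's parsed arguments and dataset.
args = argparse.Namespace(train_batch_size=64, output_dir="/path/that/varies/per/run")
ds = Dataset.from_dict({"text": ["a photo of a cat", "a photo of a dog"]})

# Normalize fields that differ between runs before hashing, so the fingerprint
# (and therefore the cached .map() result) is reused across runs.
old_output_dir, args.output_dir = args.output_dir, "/tmp/trained-model/"
old_batch_size, args.train_batch_size = args.train_batch_size, 22
new_fingerprint = Hasher.hash(args)
args.output_dir, args.train_batch_size = old_output_dir, old_batch_size

# Any map() call keyed on this fingerprint hits the same cache entry next run.
ds_with_lengths = ds.map(lambda ex: {"n_chars": len(ex["text"])}, new_fingerprint=new_fingerprint)
print(new_fingerprint, ds_with_lengths[0])
```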
```diff
@@ -829,7 +824,6 @@ def collate_fn(examples):
     print(f" Total optimization steps = {args.max_train_steps}")

     # unet = add_checkpoints(unet)
-    # import pdb; pdb.set_trace()
     trainer = TrainSD(
         weight_dtype=weight_dtype,
         device=device,
@@ -840,19 +834,15 @@ def collate_fn(examples):
         args=args,
     )
     trainer.start_training()
-    # unet = trainer.unet.to("cpu")
-    # vae = trainer.vae.to("cpu")
-    # text_encoder = trainer.text_encoder.to("cpu")
-
-    # pipeline = StableDiffusionXLPipeline.from_pretrained(
-    #     args.pretrained_model_name_or_path,
-    #     text_encoder=text_encoder,
-    #     vae=vae,
-    #     unet=unet,
-    #     revision=args.revision,
-    #     variant=args.variant,
-    # )
-    # pipeline.save_pretrained(args.output_dir)
+    unet = trainer.unet.to("cpu")
+
+    pipeline = StableDiffusionXLPipeline.from_pretrained(
+        args.pretrained_model_name_or_path,
+        unet=unet,
+        revision=args.revision,
+        variant=args.variant,
+    )
+    pipeline.save_pretrained(args.output_dir)

     # if xm.is_master_ordinal() and args.push_to_hub:
     #     save_model_card(args, repo_id, repo_folder=args.output_dir)
```
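With the pipeline export re-enabled above, the trained UNet is saved inside a full SDXL pipeline under `args.output_dir`. A minimal sketch of loading that output back for inference (the path, prompt, and dtype/device choices are illustrative):

```python
import torch
from diffusers import StableDiffusionXLPipeline

# Hypothetical path: wherever --output_dir pointed during training.
pipeline = StableDiffusionXLPipeline.from_pretrained(
    "/tmp/trained-model/", torch_dtype=torch.float16
)
pipeline.to("cuda")

image = pipeline(
    prompt="a photo of a corgi wearing sunglasses",  # illustrative prompt
    num_inference_steps=30,
).images[0]
image.save("sample.png")
```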

src/diffusers/models/attention.py

Lines changed: 73 additions & 76 deletions
```diff
@@ -340,7 +340,7 @@ def __init__(
         self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
         self.use_layer_norm = norm_type == "layer_norm"
         self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
-        assert norm_type in ["layer_norm", "layer_norm_i2vgen"]
+
         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
                 f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
@@ -359,7 +359,6 @@ def __init__(
             self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
         else:
             self.pos_embed = None
-        assert self.pos_embed == None

         # Define 3 blocks. Each block has its own normalization layer.
         # 1. Self-Attn
@@ -468,7 +467,6 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_size = chunk_size
         self._chunk_dim = dim

-    # @xp.trace_me("BasicTransformerBlock")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -480,42 +478,39 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        # import pdb; pdb.set_trace()
-        # if cross_attention_kwargs is not None:
-        #     if cross_attention_kwargs.get("scale", None) is not None:
-        #         logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
-        # batch_size = hidden_states.shape[0]
-
-        # if self.norm_type == "ada_norm":
-        #     norm_hidden_states = self.norm1(hidden_states, timestep)
-        # elif self.norm_type == "ada_norm_zero":
-        #     norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-        #         hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
-        #     )
-        # elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
-        norm_hidden_states = self.norm1(hidden_states)
-        # elif self.norm_type == "ada_norm_continuous":
-        #     norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
-        # elif self.norm_type == "ada_norm_single":
-        #     shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-        #         self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
-        #     ).chunk(6, dim=1)
-        #     norm_hidden_states = self.norm1(hidden_states)
-        #     norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
-        # else:
-        #     raise ValueError("Incorrect norm used")
-
-        # if self.pos_embed is not None:
-        #     norm_hidden_states = self.pos_embed(norm_hidden_states)
+        batch_size = hidden_states.shape[0]
+
+        if self.norm_type == "ada_norm":
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.norm_type == "ada_norm_zero":
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+            norm_hidden_states = self.norm1(hidden_states)
+        elif self.norm_type == "ada_norm_continuous":
+            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif self.norm_type == "ada_norm_single":
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+            ).chunk(6, dim=1)
+            norm_hidden_states = self.norm1(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+        else:
+            raise ValueError("Incorrect norm used")
+
+        if self.pos_embed is not None:
+            norm_hidden_states = self.pos_embed(norm_hidden_states)

         # 1. Prepare GLIGEN inputs
-        assert cross_attention_kwargs is None
-        cross_attention_kwargs = {}
-        # cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
-        # gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)

         attn_output = self.attn1(
             norm_hidden_states,
@@ -524,33 +519,36 @@ def forward(
             **cross_attention_kwargs,
         )

-        # if self.norm_type == "ada_norm_zero":
-        #     attn_output = gate_msa.unsqueeze(1) * attn_output
-        # elif self.norm_type == "ada_norm_single":
-        #     attn_output = gate_msa * attn_output
+        if self.norm_type == "ada_norm_zero":
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        elif self.norm_type == "ada_norm_single":
+            attn_output = gate_msa * attn_output

         hidden_states = attn_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
         # 1.2 GLIGEN Control
-        # if gligen_kwargs is not None:
-        #     hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+        if gligen_kwargs is not None:
+            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])

         # 3. Cross-Attention
         if self.attn2 is not None:
-            # if self.norm_type == "ada_norm":
-            #     norm_hidden_states = self.norm2(hidden_states, timestep)
-            # elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
-            norm_hidden_states = self.norm2(hidden_states)
-            # elif self.norm_type == "ada_norm_single":
-            #     # For PixArt norm2 isn't applied here:
-            #     # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
-            #     norm_hidden_states = hidden_states
-            # elif self.norm_type == "ada_norm_continuous":
-            #     norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
-            # else:
-            #     raise ValueError("Incorrect norm")
-
-            # if self.pos_embed is not None and self.norm_type != "ada_norm_single":
-            #     norm_hidden_states = self.pos_embed(norm_hidden_states)
+            if self.norm_type == "ada_norm":
+                norm_hidden_states = self.norm2(hidden_states, timestep)
+            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+                norm_hidden_states = self.norm2(hidden_states)
+            elif self.norm_type == "ada_norm_single":
+                # For PixArt norm2 isn't applied here:
+                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                norm_hidden_states = hidden_states
+            elif self.norm_type == "ada_norm_continuous":
+                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+            else:
+                raise ValueError("Incorrect norm")
+
+            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                norm_hidden_states = self.pos_embed(norm_hidden_states)

             attn_output = self.attn2(
                 norm_hidden_states,
@@ -562,33 +560,32 @@ def forward(

         # 4. Feed-forward
         # i2vgen doesn't have this norm 🤷‍♂️
-        # if self.norm_type == "ada_norm_continuous":
-        #     norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
-        # elif not self.norm_type == "ada_norm_single":
-        norm_hidden_states = self.norm3(hidden_states)
+        if self.norm_type == "ada_norm_continuous":
+            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif not self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm3(hidden_states)

-        # if self.norm_type == "ada_norm_zero":
-        #     norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self.norm_type == "ada_norm_zero":
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]

-        # if self.norm_type == "ada_norm_single":
-        #     norm_hidden_states = self.norm2(hidden_states)
-        #     norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+        if self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp

-        assert self._chunk_size == None
-        # if self._chunk_size is not None:
-        #     # "feed_forward_chunk_size" can be used to save memory
-        #     ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
-        # else:
-        ff_output = self.ff(norm_hidden_states)
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)

-        # if self.norm_type == "ada_norm_zero":
-        #     ff_output = gate_mlp.unsqueeze(1) * ff_output
-        # elif self.norm_type == "ada_norm_single":
-        #     ff_output = gate_mlp * ff_output
+        if self.norm_type == "ada_norm_zero":
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+        elif self.norm_type == "ada_norm_single":
+            ff_output = gate_mlp * ff_output

         hidden_states = ff_output + hidden_states
-        # if hidden_states.ndim == 4:
-        #     hidden_states = hidden_states.squeeze(1)
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)

         return hidden_states

```
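With the hard-coded shortcuts reverted, `BasicTransformerBlock.forward` again dispatches on `norm_type` instead of assuming the `layer_norm` path. A minimal standalone sketch exercising the restored default path (the dimensions and `cross_attention_dim` below are illustrative):

```python
import torch
from diffusers.models.attention import BasicTransformerBlock

# Illustrative configuration: default norm_type="layer_norm", with a cross-attention layer.
block = BasicTransformerBlock(
    dim=320,
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,
)

hidden_states = torch.randn(2, 64, 320)          # (batch, image tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # (batch, text tokens, cross_attention_dim)

with torch.no_grad():
    out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)

print(out.shape)  # torch.Size([2, 64, 320])
```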
