@@ -340,7 +340,7 @@ def __init__(
         self.use_ada_layer_norm_single = norm_type == "ada_norm_single"
         self.use_layer_norm = norm_type == "layer_norm"
         self.use_ada_layer_norm_continuous = norm_type == "ada_norm_continuous"
-        assert norm_type in ["layer_norm", "layer_norm_i2vgen"]
+
         if norm_type in ("ada_norm", "ada_norm_zero") and num_embeds_ada_norm is None:
             raise ValueError(
                 f"`norm_type` is set to {norm_type}, but `num_embeds_ada_norm` is not defined. Please make sure to"
@@ -359,7 +359,6 @@ def __init__(
             self.pos_embed = SinusoidalPositionalEmbedding(dim, max_seq_length=num_positional_embeddings)
         else:
            self.pos_embed = None
-        assert self.pos_embed == None
 
         # Define 3 blocks. Each block has its own normalization layer.
         # 1. Self-Attn
@@ -468,7 +467,6 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_size = chunk_size
         self._chunk_dim = dim
 
-    # @xp.trace_me("BasicTransformerBlock")
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -480,42 +478,39 @@ def forward(
         class_labels: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
     ) -> torch.Tensor:
-        # import pdb; pdb.set_trace()
-        # if cross_attention_kwargs is not None:
-        #     if cross_attention_kwargs.get("scale", None) is not None:
-        #         logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
+        if cross_attention_kwargs is not None:
+            if cross_attention_kwargs.get("scale", None) is not None:
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         # Notice that normalization is always applied before the real computation in the following blocks.
         # 0. Self-Attention
-        # batch_size = hidden_states.shape[0]
-
-        # if self.norm_type == "ada_norm":
-        #     norm_hidden_states = self.norm1(hidden_states, timestep)
-        # elif self.norm_type == "ada_norm_zero":
-        #     norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
-        #         hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
-        #     )
-        # elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
-        norm_hidden_states = self.norm1(hidden_states)
-        # elif self.norm_type == "ada_norm_continuous":
-        #     norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
-        # elif self.norm_type == "ada_norm_single":
-        #     shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
-        #         self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
-        #     ).chunk(6, dim=1)
-        #     norm_hidden_states = self.norm1(hidden_states)
-        #     norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
-        # else:
-        #     raise ValueError("Incorrect norm used")
-
-        # if self.pos_embed is not None:
-        #     norm_hidden_states = self.pos_embed(norm_hidden_states)
+        batch_size = hidden_states.shape[0]
+
+        if self.norm_type == "ada_norm":
+            norm_hidden_states = self.norm1(hidden_states, timestep)
+        elif self.norm_type == "ada_norm_zero":
+            norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(
+                hidden_states, timestep, class_labels, hidden_dtype=hidden_states.dtype
+            )
+        elif self.norm_type in ["layer_norm", "layer_norm_i2vgen"]:
+            norm_hidden_states = self.norm1(hidden_states)
+        elif self.norm_type == "ada_norm_continuous":
+            norm_hidden_states = self.norm1(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif self.norm_type == "ada_norm_single":
+            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
+                self.scale_shift_table[None] + timestep.reshape(batch_size, 6, -1)
+            ).chunk(6, dim=1)
+            norm_hidden_states = self.norm1(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa
+        else:
+            raise ValueError("Incorrect norm used")
+
+        if self.pos_embed is not None:
+            norm_hidden_states = self.pos_embed(norm_hidden_states)
 
         # 1. Prepare GLIGEN inputs
-        assert cross_attention_kwargs is None
-        cross_attention_kwargs = {}
-        # cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
-        # gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
+        cross_attention_kwargs = cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {}
+        gligen_kwargs = cross_attention_kwargs.pop("gligen", None)
 
         attn_output = self.attn1(
             norm_hidden_states,
@@ -524,33 +519,36 @@ def forward(
             **cross_attention_kwargs,
         )
 
-        # if self.norm_type == "ada_norm_zero":
-        #     attn_output = gate_msa.unsqueeze(1) * attn_output
-        # elif self.norm_type == "ada_norm_single":
-        #     attn_output = gate_msa * attn_output
+        if self.norm_type == "ada_norm_zero":
+            attn_output = gate_msa.unsqueeze(1) * attn_output
+        elif self.norm_type == "ada_norm_single":
+            attn_output = gate_msa * attn_output
 
         hidden_states = attn_output + hidden_states
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
+
         # 1.2 GLIGEN Control
-        # if gligen_kwargs is not None:
-        #     hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
+        if gligen_kwargs is not None:
+            hidden_states = self.fuser(hidden_states, gligen_kwargs["objs"])
 
         # 3. Cross-Attention
         if self.attn2 is not None:
-            # if self.norm_type == "ada_norm":
-            #     norm_hidden_states = self.norm2(hidden_states, timestep)
-            # elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
-            norm_hidden_states = self.norm2(hidden_states)
-            # elif self.norm_type == "ada_norm_single":
-            #     # For PixArt norm2 isn't applied here:
-            #     # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
-            #     norm_hidden_states = hidden_states
-            # elif self.norm_type == "ada_norm_continuous":
-            #     norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
-            # else:
-            #     raise ValueError("Incorrect norm")
-
-            # if self.pos_embed is not None and self.norm_type != "ada_norm_single":
-            #     norm_hidden_states = self.pos_embed(norm_hidden_states)
+            if self.norm_type == "ada_norm":
+                norm_hidden_states = self.norm2(hidden_states, timestep)
+            elif self.norm_type in ["ada_norm_zero", "layer_norm", "layer_norm_i2vgen"]:
+                norm_hidden_states = self.norm2(hidden_states)
+            elif self.norm_type == "ada_norm_single":
+                # For PixArt norm2 isn't applied here:
+                # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L70C1-L76C103
+                norm_hidden_states = hidden_states
+            elif self.norm_type == "ada_norm_continuous":
+                norm_hidden_states = self.norm2(hidden_states, added_cond_kwargs["pooled_text_emb"])
+            else:
+                raise ValueError("Incorrect norm")
+
+            if self.pos_embed is not None and self.norm_type != "ada_norm_single":
+                norm_hidden_states = self.pos_embed(norm_hidden_states)
 
         attn_output = self.attn2(
             norm_hidden_states,
@@ -562,33 +560,32 @@ def forward(
 
         # 4. Feed-forward
         # i2vgen doesn't have this norm 🤷♂️
-        # if self.norm_type == "ada_norm_continuous":
-        #     norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
-        # elif not self.norm_type == "ada_norm_single":
-        norm_hidden_states = self.norm3(hidden_states)
+        if self.norm_type == "ada_norm_continuous":
+            norm_hidden_states = self.norm3(hidden_states, added_cond_kwargs["pooled_text_emb"])
+        elif not self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm3(hidden_states)
 
-        # if self.norm_type == "ada_norm_zero":
-        #     norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self.norm_type == "ada_norm_zero":
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
 
-        # if self.norm_type == "ada_norm_single":
-        #     norm_hidden_states = self.norm2(hidden_states)
-        #     norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
+        if self.norm_type == "ada_norm_single":
+            norm_hidden_states = self.norm2(hidden_states)
+            norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp
 
-        assert self._chunk_size == None
-        # if self._chunk_size is not None:
-        #     # "feed_forward_chunk_size" can be used to save memory
-        #     ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
-        # else:
-        ff_output = self.ff(norm_hidden_states)
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
 
-        # if self.norm_type == "ada_norm_zero":
-        #     ff_output = gate_mlp.unsqueeze(1) * ff_output
-        # elif self.norm_type == "ada_norm_single":
-        #     ff_output = gate_mlp * ff_output
+        if self.norm_type == "ada_norm_zero":
+            ff_output = gate_mlp.unsqueeze(1) * ff_output
+        elif self.norm_type == "ada_norm_single":
+            ff_output = gate_mlp * ff_output
 
         hidden_states = ff_output + hidden_states
-        # if hidden_states.ndim == 4:
-        #     hidden_states = hidden_states.squeeze(1)
+        if hidden_states.ndim == 4:
+            hidden_states = hidden_states.squeeze(1)
 
         return hidden_states
 
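For reference, a minimal sketch of how the restored forward() branching is exercised through the public BasicTransformerBlock API. The constructor arguments, tensor shapes, and the diffusers.models.attention import path are illustrative assumptions, not part of this commit.

import torch
from diffusers.models.attention import BasicTransformerBlock

# Default "layer_norm" path: norm1/norm2/norm3 are plain LayerNorms, so the
# forward() above takes the `elif self.norm_type in ["layer_norm", ...]` branches.
block = BasicTransformerBlock(
    dim=320,                    # hypothetical sizes for illustration
    num_attention_heads=8,
    attention_head_dim=40,
    cross_attention_dim=768,    # enables attn2 (cross-attention)
    norm_type="layer_norm",
)

hidden_states = torch.randn(2, 64, 320)          # (batch, tokens, dim)
encoder_hidden_states = torch.randn(2, 77, 768)  # e.g. text-encoder features

with torch.no_grad():
    out = block(hidden_states, encoder_hidden_states=encoder_hidden_states)

print(out.shape)  # torch.Size([2, 64, 320]) -- the residual stream keeps its shape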