 from ...loaders import PeftAdapterMixin
 from ...models.modeling_outputs import Transformer2DModelOutput
 from ...models.modeling_utils import ModelMixin
-from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention
 from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,108 @@ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_train
         x = torch.cat(x_arr, dim=0)
         return x
 
-    def patchify(self, x, max_seq, img_sizes=None):
-        pz2 = self.config.patch_size * self.config.patch_size
-        if isinstance(x, torch.Tensor):
-            B, C = x.shape[0], x.shape[1]
-            device = x.device
-            dtype = x.dtype
+    def patchify(self, hidden_states):
+        batch_size, channels, height, width = hidden_states.shape
+        patch_size = self.config.patch_size
+        patch_height, patch_width = height // patch_size, width // patch_size
+        device = hidden_states.device
+        dtype = hidden_states.dtype
+
+        # create img_sizes
+        img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
+        img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
+
+        # create hidden_states_masks
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)
+            hidden_states_masks[:, : patch_height * patch_width] = 1.0
         else:
-            B, C = len(x), x[0].shape[0]
-            device = x[0].device
-            dtype = x[0].dtype
-        x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
+            hidden_states_masks = None
+
+        # create img_ids
+        img_ids = torch.zeros(patch_height, patch_width, 3, device=device)
+        row_indices = torch.arange(patch_height, device=device)[:, None]
+        col_indices = torch.arange(patch_width, device=device)[None, :]
+        img_ids[..., 1] = img_ids[..., 1] + row_indices
+        img_ids[..., 2] = img_ids[..., 2] + col_indices
+        img_ids = img_ids.reshape(patch_height * patch_width, -1)
+
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
+            img_ids_pad[: patch_height * patch_width, :] = img_ids
+            img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)
+        else:
+            img_ids = img_ids.unsqueeze(0).repeat(batch_size, 1, 1)
+
+        # patchify hidden_states
+        if hidden_states.shape[-2] != hidden_states.shape[-1]:
+            # Handle non-square latents
+            out = torch.zeros(
+                (batch_size, channels, self.max_seq, patch_size * patch_size),
+                dtype=dtype,
+                device=device,
+            )
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height * patch_width, patch_size * patch_size
+            )
+            out[:, :, 0 : patch_height * patch_width] = hidden_states
+            hidden_states = out
+            hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
+                batch_size, self.max_seq, patch_size * patch_size * channels
+            )
 
-        if img_sizes is not None:
-            for i, img_size in enumerate(img_sizes):
-                x_masks[i, 0 : img_size[0] * img_size[1]] = 1
-            B, C, S, _ = x.shape
-            x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
-        elif isinstance(x, torch.Tensor):
-            B, C, Hp1, Wp2 = x.shape
-            pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
-            x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
-            x = x.permute(0, 2, 4, 3, 5, 1)
-            x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
-            img_sizes = [[pH, pW]] * B
-            x_masks = None
         else:
-            raise NotImplementedError
-        return x, x_masks, img_sizes
+            # Handle square latents
+            hidden_states = hidden_states.reshape(
+                batch_size, channels, patch_height, patch_size, patch_width, patch_size
+            )
+            hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
+            hidden_states = hidden_states.reshape(
+                batch_size, patch_height * patch_width, patch_size * patch_size * channels
+            )
+
+        return hidden_states, hidden_states_masks, img_sizes, img_ids
 
     def forward(
         self,
         hidden_states: torch.Tensor,
         timesteps: torch.LongTensor = None,
-        encoder_hidden_states: torch.Tensor = None,
+        encoder_hidden_states_t5: torch.Tensor = None,
+        encoder_hidden_states_llama3: torch.Tensor = None,
         pooled_embeds: torch.Tensor = None,
-        img_sizes: Optional[List[Tuple[int, int]]] = None,
         img_ids: Optional[torch.Tensor] = None,
+        img_sizes: Optional[List[Tuple[int, int]]] = None,
+        hidden_states_masks: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
+        **kwargs,
     ):
+        encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
+
+        if encoder_hidden_states is not None:
+            deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
+            deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
+            encoder_hidden_states_t5 = encoder_hidden_states[0]
+            encoder_hidden_states_llama3 = encoder_hidden_states[1]
+
+        if img_ids is not None and img_sizes is not None and hidden_states_masks is None:
+            deprecation_message = (
+                "Passing `img_ids` and `img_sizes` with unpatchified `hidden_states` is deprecated and will be ignored."
+            )
+            deprecate("img_ids", "0.34.0", deprecation_message)
+
+        if hidden_states_masks is not None and (img_ids is None or img_sizes is None):
+            raise ValueError("if `hidden_states_masks` is passed, `img_ids` and `img_sizes` must also be passed.")
+        elif hidden_states_masks is not None and hidden_states.ndim != 3:
+            raise ValueError(
+                "if `hidden_states_masks` is passed, `hidden_states` must be a 3D tensor with shape (batch_size, patch_height * patch_width, patch_size * patch_size * channels)"
+            )
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
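For readers following the refactor, here is a minimal standalone sketch of what the square-latent branch of the new `patchify` computes, i.e. how a `(batch, channels, height, width)` latent becomes a `(batch, num_patches, patch_size * patch_size * channels)` token sequence. The helper name and the example shapes are illustrative assumptions, not part of the diff:

```python
import torch

def patchify_square(hidden_states: torch.Tensor, patch_size: int) -> torch.Tensor:
    # Mirrors the "Handle square latents" branch above (illustrative only;
    # the real method also returns hidden_states_masks, img_sizes, and img_ids).
    batch_size, channels, height, width = hidden_states.shape
    patch_height, patch_width = height // patch_size, width // patch_size
    # (B, C, pH, p, pW, p) -> (B, pH, pW, p, p, C) -> (B, pH*pW, p*p*C)
    hidden_states = hidden_states.reshape(
        batch_size, channels, patch_height, patch_size, patch_width, patch_size
    )
    hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
    return hidden_states.reshape(
        batch_size, patch_height * patch_width, patch_size * patch_size * channels
    )

latents = torch.randn(1, 16, 128, 128)  # assumed example shape
tokens = patchify_square(latents, patch_size=2)
print(tokens.shape)  # torch.Size([1, 4096, 64])
```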
@@ -745,42 +807,19 @@ def forward(
         batch_size = hidden_states.shape[0]
         hidden_states_type = hidden_states.dtype
 
-        if hidden_states.shape[-2] != hidden_states.shape[-1]:
-            B, C, H, W = hidden_states.shape
-            patch_size = self.config.patch_size
-            pH, pW = H // patch_size, W // patch_size
-            out = torch.zeros(
-                (B, C, self.max_seq, patch_size * patch_size),
-                dtype=hidden_states.dtype,
-                device=hidden_states.device,
-            )
-            hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
-            hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
-            hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
-            out[:, :, 0 : pH * pW] = hidden_states
-            hidden_states = out
+        # Patchify the input
+        if hidden_states_masks is None:
+            hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
+
+        # Embed the hidden states
+        hidden_states = self.x_embedder(hidden_states)
 
         # 0. time
         timesteps = self.t_embedder(timesteps, hidden_states_type)
         p_embedder = self.p_embedder(pooled_embeds)
         temb = timesteps + p_embedder
 
-        hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
-        if hidden_states_masks is None:
-            pH, pW = img_sizes[0]
-            img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
-            img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
-            img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
-            img_ids = (
-                img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
-                .unsqueeze(0)
-                .repeat(batch_size, 1, 1)
-            )
-        hidden_states = self.x_embedder(hidden_states)
-
-        T5_encoder_hidden_states = encoder_hidden_states[0]
-        encoder_hidden_states = encoder_hidden_states[-1]
-        encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
+        encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]
 
         if self.caption_projection is not None:
             new_encoder_hidden_states = []
@@ -789,9 +828,9 @@ def forward(
                 enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
                 new_encoder_hidden_states.append(enc_hidden_state)
             encoder_hidden_states = new_encoder_hidden_states
-            T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
-            T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
-            encoder_hidden_states.append(T5_encoder_hidden_states)
+            encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
+            encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
+            encoder_hidden_states.append(encoder_hidden_states_t5)
 
         txt_ids = torch.zeros(
             batch_size,
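To illustrate the call-site impact of the deprecation shim above, here is a hedged sketch of migrating a caller from the single `encoder_hidden_states` argument to the new split keyword arguments. `transformer`, `latents`, `timesteps`, `t5_embeds`, `llama3_embeds`, and `pooled_embeds` are placeholder names for objects the caller already has:

```python
# Before (deprecated): both text-encoder outputs packed into one argument.
# This still works but now emits a deprecation warning via `deprecate(...)`.
output = transformer(
    hidden_states=latents,
    timesteps=timesteps,
    encoder_hidden_states=(t5_embeds, llama3_embeds),
    pooled_embeds=pooled_embeds,
)

# After: pass the T5 and Llama3 hidden states separately.
output = transformer(
    hidden_states=latents,
    timesteps=timesteps,
    encoder_hidden_states_t5=t5_embeds,
    encoder_hidden_states_llama3=llama3_embeds,
    pooled_embeds=pooled_embeds,
)
```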