@@ -139,6 +139,7 @@ def retrieve_timesteps(
 
 class MochiPipeline(
     DiffusionPipeline,
+    TextualInversionLoaderMixin,
 ):
     r"""
     The Mochi pipeline for text-to-video generation.
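Mixing in TextualInversionLoaderMixin gives the pipeline diffusers' textual-inversion loading API. A hedged usage sketch; the repo id, embedding path, and token below are placeholders, and for loaded tokens to take effect the prompt-encoding path would also need to route prompts through maybe_convert_prompt, which this commit does not show:

    import torch
    from diffusers import MochiPipeline  # assumes the pipeline is exported; hypothetical at this commit

    pipe = MochiPipeline.from_pretrained("genmo/mochi-1-preview", torch_dtype=torch.bfloat16)  # placeholder repo id
    pipe.load_textual_inversion("path/to/embedding.safetensors", token="<my-token>")  # placeholder path and token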
@@ -187,14 +188,17 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
         )
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
-        )
+        # TODO: determine these scaling factors from model parameters
+        self.vae_spatial_scale_factor = 8
+        self.vae_temporal_scale_factor = 6
+        self.patch_size = 2
+
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_height = 64
+        self.default_width = 64
 
     def _get_t5_prompt_embeds(
         self,
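The hard-coded factors above imply a fixed mapping from pixel-space video size to latent and token dimensions, which the TODO leaves implicit. A minimal sketch of that mapping, for orientation only; the helper name, the causal first-frame handling, and the floor rounding are assumptions, not part of this commit:

    # Hypothetical helper (not part of this commit) showing how the factors
    # above would translate pixel-space dimensions into the latent dimensions
    # that _pack_indices consumes further down.
    def compute_latent_dims(num_frames, height, width,
                            spatial=8, temporal=6, patch_size=2):
        # Assumption: a causal video VAE encodes the first frame alone, then
        # compresses each group of `temporal` frames into one latent frame.
        latent_frames = (num_frames - 1) // temporal + 1
        latent_height = height // spatial
        latent_width = width // spatial
        # The transformer patchifies each latent frame into (patch_size x patch_size)
        # patches, matching N in _pack_indices below.
        num_visual_tokens = latent_frames * latent_height * latent_width // patch_size**2
        return latent_frames, latent_height, latent_width, num_visual_tokens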
@@ -235,7 +239,7 @@ def _get_t5_prompt_embeds(
                 f" {max_sequence_length} tokens: {removed_text}"
             )
 
-        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False)[0]
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask, output_hidden_states=False).last_hidden_state
 
         dtype = self.text_encoder.dtype
         prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
@@ -246,7 +250,32 @@ def _get_t5_prompt_embeds(
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
         prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
 
-        return prompt_embeds
+        return prompt_embeds, prompt_attention_mask
+
+    def _pack_indices(self, attention_mask, latent_frames_dim, latent_height_dim, latent_width_dim):
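+        # Total number of visual (latent video) tokens once each latent frame
+        # is split into (patch_size x patch_size) patches.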
+        N = latent_frames_dim * latent_height_dim * latent_width_dim // (self.patch_size**2)
+
+        # Create an expanded token mask saying which tokens are valid across both visual and text tokens.
+        assert N > 0 and len(attention_mask) == 1
+        attention_mask = attention_mask[0]
+
+        mask = F.pad(attention_mask, (N, 0), value=True)  # (B, N + L)
+        seqlens_in_batch = mask.sum(dim=-1, dtype=torch.int32)  # (B,)
+        valid_token_indices = torch.nonzero(
+            mask.flatten(), as_tuple=False
+        ).flatten()  # up to (B * (N + L),)
+
+        assert valid_token_indices.size(0) >= attention_mask.size(0) * N  # At least (B * N,)
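+        # Offsets of each batch element within the packed (flattened) sequence:
+        # a leading 0 followed by the cumulative per-sample token counts, the
+        # layout expected by varlen ("packed") attention kernels.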
+        cu_seqlens = F.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        max_seqlen_in_batch = seqlens_in_batch.max().item()
+
+        return {
+            "cu_seqlens_kv": cu_seqlens,
+            "max_seqlen_in_batch_kv": max_seqlen_in_batch,
+            "valid_token_indices_kv": valid_token_indices,
+        }
 
     def encode_prompt(
         self,
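For orientation, a hedged sketch of how the dictionary returned by _pack_indices would typically be consumed by a varlen ("packed") attention kernel such as flash-attn's flash_attn_varlen_func. Everything below is illustrative and not part of this commit: the stub object, tensor names, and dummy shapes are assumptions, and _pack_indices itself requires the module to import torch.nn.functional as F:

    import torch

    class _Stub:
        # Stand-in carrying the one attribute _pack_indices reads from `self`.
        patch_size = 2

    B, L, D = 2, 256, 64                        # batch, text length, channel dim
    latent_frames, latent_h, latent_w = 16, 32, 8
    N = latent_frames * latent_h * latent_w // _Stub.patch_size**2  # 1024 visual tokens

    text_mask = torch.zeros(B, L, dtype=torch.bool)
    text_mask[0, :100] = True                   # prompt 1: 100 valid text tokens
    text_mask[1, :37] = True                    # prompt 2: 37 valid text tokens

    packed = MochiPipeline._pack_indices(_Stub(), [text_mask], latent_frames, latent_h, latent_w)

    # Drop padding: flatten (B, N + L, D) and keep only rows of valid tokens.
    tokens = torch.randn(B, N + L, D)           # visual tokens first, then text tokens
    packed_kv = tokens.reshape(-1, D)[packed["valid_token_indices_kv"]]

    # A varlen kernel then recovers per-sample boundaries from the offsets:
    #   cu_seqlens_k=packed["cu_seqlens_kv"],             # (B + 1,) int32
    #   max_seqlen_k=packed["max_seqlen_in_batch_kv"],    # Python int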