@@ -188,6 +188,7 @@ def __init__(
         text_encoder: T5EncoderModel,
         tokenizer: T5TokenizerFast,
         transformer: MochiTransformer3DModel,
+        force_zeros_for_empty_prompt: bool = False,
     ):
         super().__init__()
 
@@ -205,10 +206,11 @@ def __init__(
 
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_spatial_scale_factor)
         self.tokenizer_max_length = (
-            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 256
         )
         self.default_height = 480
         self.default_width = 848
+        self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
 
     def _get_t5_prompt_embeds(
         self,
@@ -236,7 +238,11 @@ def _get_t5_prompt_embeds(
         text_input_ids = text_inputs.input_ids
         prompt_attention_mask = text_inputs.attention_mask
         prompt_attention_mask = prompt_attention_mask.bool().to(device)
-        if prompt == "" or prompt[-1] == "":
+
+        # The original Mochi implementation zeros out empty negative prompts,
+        # but this can lead to overflow when the entire pipeline is placed under an autocast context.
+        # Adding this flag here so that prompt zeroing can be enabled when needed.
+        if self.config.force_zeros_for_empty_prompt and (prompt == "" or prompt[-1] == ""):
            text_input_ids = torch.zeros_like(text_input_ids, device=device)
            prompt_attention_mask = torch.zeros_like(prompt_attention_mask, dtype=torch.bool, device=device)
 
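For reference, a minimal usage sketch of the new flag, assuming it is forwarded through `MochiPipeline.from_pretrained` like other `__init__`/config arguments; the checkpoint ID and generation settings below are illustrative, not part of this diff:

import torch
from diffusers import MochiPipeline
from diffusers.utils import export_to_video

# Illustrative checkpoint ID; any checkpoint compatible with MochiPipeline works.
pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview",
    torch_dtype=torch.bfloat16,
    # New flag from this diff; defaults to False, so existing behavior is unchanged.
    force_zeros_for_empty_prompt=True,
)
pipe.to("cuda")

# With the flag enabled, an empty negative prompt is encoded as all-zero token ids
# with an all-False attention mask, matching the original Mochi behavior.
frames = pipe(
    prompt="A corgi surfing a small wave at sunset",
    negative_prompt="",
    num_inference_steps=64,
).frames[0]
export_to_video(frames, "mochi.mp4", fps=30)

Leaving the flag at its default keeps the encoded embeddings for empty prompts intact, which avoids the autocast overflow noted in the comment above.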