@@ -594,25 +594,24 @@ def __call__(
594594 batch_size = prompt_embeds .shape [0 ]
595595
596596 device = self ._execution_device
597- with torch .autocast ("cuda" , torch .float32 ):
598- # 3. Prepare text embeddings
599- (
600- prompt_embeds ,
601- prompt_attention_mask ,
602- negative_prompt_embeds ,
603- negative_prompt_attention_mask ,
604- ) = self .encode_prompt (
605- prompt = prompt ,
606- negative_prompt = negative_prompt ,
607- do_classifier_free_guidance = self .do_classifier_free_guidance ,
608- num_videos_per_prompt = num_videos_per_prompt ,
609- prompt_embeds = prompt_embeds ,
610- negative_prompt_embeds = negative_prompt_embeds ,
611- prompt_attention_mask = prompt_attention_mask ,
612- negative_prompt_attention_mask = negative_prompt_attention_mask ,
613- max_sequence_length = max_sequence_length ,
614- device = device ,
615- )
597+ # 3. Prepare text embeddings
598+ (
599+ prompt_embeds ,
600+ prompt_attention_mask ,
601+ negative_prompt_embeds ,
602+ negative_prompt_attention_mask ,
603+ ) = self .encode_prompt (
604+ prompt = prompt ,
605+ negative_prompt = negative_prompt ,
606+ do_classifier_free_guidance = self .do_classifier_free_guidance ,
607+ num_videos_per_prompt = num_videos_per_prompt ,
608+ prompt_embeds = prompt_embeds ,
609+ negative_prompt_embeds = negative_prompt_embeds ,
610+ prompt_attention_mask = prompt_attention_mask ,
611+ negative_prompt_attention_mask = negative_prompt_attention_mask ,
612+ max_sequence_length = max_sequence_length ,
613+ device = device ,
614+ )
616615 # 4. Prepare latent variables
617616 num_channels_latents = self .transformer .config .in_channels
618617 latents = self .prepare_latents (
0 commit comments