@@ -594,24 +594,25 @@ def __call__(
         batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # 3. Prepare text embeddings
-        (
-            prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_embeds,
-            negative_prompt_attention_mask,
-        ) = self.encode_prompt(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            num_videos_per_prompt=num_videos_per_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
-            max_sequence_length=max_sequence_length,
-            device=device,
-        )
+        with torch.autocast("cuda", torch.float32):
+            # 3. Prepare text embeddings
+            (
+                prompt_embeds,
+                prompt_attention_mask,
+                negative_prompt_embeds,
+                negative_prompt_attention_mask,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                num_videos_per_prompt=num_videos_per_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                prompt_attention_mask=prompt_attention_mask,
+                negative_prompt_attention_mask=negative_prompt_attention_mask,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels
         latents = self.prepare_latents(
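For context, a minimal standalone sketch of the torch.autocast scoping this change leans on (the module and tensor names below are illustrative, not code from this pipeline; a CUDA device is assumed). CUDA autocast documents only torch.float16 and torch.bfloat16 as target dtypes, so depending on the PyTorch version a torch.float32 argument may just emit a warning and disable autocast for the block; the documented idiom for pinning a sub-region to full precision under an outer autocast context is enabled=False, shown here:

import torch

linear = torch.nn.Linear(8, 8, device="cuda")   # float32 weights
x = torch.randn(4, 8, device="cuda")

with torch.autocast("cuda", torch.float16):
    y_half = linear(x)                     # autocast-eligible op runs in float16
    with torch.autocast("cuda", enabled=False):
        y_full = linear(x)                 # autocast disabled locally -> float32

print(y_half.dtype, y_full.dtype)          # torch.float16 torch.float32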