@@ -601,24 +601,25 @@ def __call__(
         batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        # 3. Prepare text embeddings
-        (
-            prompt_embeds,
-            prompt_attention_mask,
-            negative_prompt_embeds,
-            negative_prompt_attention_mask,
-        ) = self.encode_prompt(
-            prompt=prompt,
-            negative_prompt=negative_prompt,
-            do_classifier_free_guidance=self.do_classifier_free_guidance,
-            num_videos_per_prompt=num_videos_per_prompt,
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            prompt_attention_mask=prompt_attention_mask,
-            negative_prompt_attention_mask=negative_prompt_attention_mask,
-            max_sequence_length=max_sequence_length,
-            device=device,
-        )
+        with torch.autocast("cuda", torch.float32):
+            # 3. Prepare text embeddings
+            (
+                prompt_embeds,
+                prompt_attention_mask,
+                negative_prompt_embeds,
+                negative_prompt_attention_mask,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                negative_prompt=negative_prompt,
+                do_classifier_free_guidance=self.do_classifier_free_guidance,
+                num_videos_per_prompt=num_videos_per_prompt,
+                prompt_embeds=prompt_embeds,
+                negative_prompt_embeds=negative_prompt_embeds,
+                prompt_attention_mask=prompt_attention_mask,
+                negative_prompt_attention_mask=negative_prompt_attention_mask,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
         # if self.do_classifier_free_guidance:
         #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         #     prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
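For reference, a minimal standalone sketch of the `torch.autocast` context-manager pattern this hunk introduces (the tensors below are hypothetical, not taken from the pipeline). One caveat: CUDA autocast is designed for reduced-precision dtypes (`torch.float16` / `torch.bfloat16`), so passing `torch.float32` as in the hunk above may, depending on the PyTorch version, emit a warning and leave autocast disabled, with ops simply running in full precision.

```python
import torch

if torch.cuda.is_available():
    # Inputs created in the default dtype (float32).
    x = torch.randn(8, 16, device="cuda")
    w = torch.randn(16, 32, device="cuda")

    # Inside the context, autocast-eligible ops (e.g. matmul) run in the
    # requested reduced-precision dtype.
    with torch.autocast("cuda", torch.float16):
        y = x @ w
    print(y.dtype)        # torch.float16

    # Outside the context, the same op stays in full precision.
    print((x @ w).dtype)  # torch.float32
```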