@@ -601,25 +601,24 @@ def __call__(
         batch_size = prompt_embeds.shape[0]

         device = self._execution_device
-        with torch.autocast("cuda", torch.float32):
-            # 3. Prepare text embeddings
-            (
-                prompt_embeds,
-                prompt_attention_mask,
-                negative_prompt_embeds,
-                negative_prompt_attention_mask,
-            ) = self.encode_prompt(
-                prompt=prompt,
-                negative_prompt=negative_prompt,
-                do_classifier_free_guidance=self.do_classifier_free_guidance,
-                num_videos_per_prompt=num_videos_per_prompt,
-                prompt_embeds=prompt_embeds,
-                negative_prompt_embeds=negative_prompt_embeds,
-                prompt_attention_mask=prompt_attention_mask,
-                negative_prompt_attention_mask=negative_prompt_attention_mask,
-                max_sequence_length=max_sequence_length,
-                device=device,
-            )
+        # 3. Prepare text embeddings
+        (
+            prompt_embeds,
+            prompt_attention_mask,
+            negative_prompt_embeds,
+            negative_prompt_attention_mask,
+        ) = self.encode_prompt(
+            prompt=prompt,
+            negative_prompt=negative_prompt,
+            do_classifier_free_guidance=self.do_classifier_free_guidance,
+            num_videos_per_prompt=num_videos_per_prompt,
+            prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
+            prompt_attention_mask=prompt_attention_mask,
+            negative_prompt_attention_mask=negative_prompt_attention_mask,
+            max_sequence_length=max_sequence_length,
+            device=device,
+        )
         # if self.do_classifier_free_guidance:
         #     prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         #     prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
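
Context for why the wrapper could be dropped (an observation about PyTorch autocast semantics, not something stated in the diff): on PyTorch 2.x, CUDA autocast only supports torch.float16 and torch.bfloat16, so torch.autocast("cuda", torch.float32) warns "In CUDA autocast, but the target dtype is not supported. Disabling autocast." and disables itself, making the removed context manager a no-op around encode_prompt. A minimal standalone sketch of that behavior, assuming a CUDA-capable PyTorch 2.x build:

    import torch

    # Assumption: PyTorch 2.x semantics -- CUDA autocast accepts only
    # float16/bfloat16; an unsupported dtype disables autocast with a warning.
    if torch.cuda.is_available():
        x = torch.randn(4, 4, device="cuda", dtype=torch.float32)

        with torch.autocast("cuda", torch.float32):  # warns, autocast disabled
            y = x @ x
        print(y.dtype)  # torch.float32 -- nothing was cast, same as no wrapper

        with torch.autocast("cuda", torch.float16):  # supported dtype
            y = x @ x
        print(y.dtype)  # torch.float16 -- matmul actually ran under autocast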