|
1 | | -# Copyright 2025 The HuggingFace Team. All rights reserved. |
| 1 | +# Copyright 2025 The DEVAIEXP Team and The HuggingFace Team. All rights reserved. |
2 | 2 | # |
3 | 3 | # Licensed under the Apache License, Version 2.0 (the "License"); |
4 | 4 | # you may not use this file except in compliance with the License. |
@@ -1070,32 +1070,32 @@ def __call__( |
1070 | 1070 | text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1]) |
1071 | 1071 | else: |
1072 | 1072 | text_encoder_projection_dim = self.text_encoder_2.config.projection_dim |
1073 | | - add_time_ids = self._get_add_time_ids( |
1074 | | - original_size, |
1075 | | - crops_coords_top_left[row][col], |
1076 | | - target_size, |
| 1073 | + add_time_ids = self._get_add_time_ids( |
| 1074 | + original_size, |
| 1075 | + crops_coords_top_left[row][col], |
| 1076 | + target_size, |
| 1077 | + dtype=prompt_embeds.dtype, |
| 1078 | + text_encoder_projection_dim=text_encoder_projection_dim, |
| 1079 | + ) |
| 1080 | + if negative_original_size is not None and negative_target_size is not None: |
| 1081 | + negative_add_time_ids = self._get_add_time_ids( |
| 1082 | + negative_original_size, |
| 1083 | + negative_crops_coords_top_left[row][col], |
| 1084 | + negative_target_size, |
1077 | 1085 | dtype=prompt_embeds.dtype, |
1078 | 1086 | text_encoder_projection_dim=text_encoder_projection_dim, |
1079 | 1087 | ) |
1080 | | - if negative_original_size is not None and negative_target_size is not None: |
1081 | | - negative_add_time_ids = self._get_add_time_ids( |
1082 | | - negative_original_size, |
1083 | | - negative_crops_coords_top_left[row][col], |
1084 | | - negative_target_size, |
1085 | | - dtype=prompt_embeds.dtype, |
1086 | | - text_encoder_projection_dim=text_encoder_projection_dim, |
1087 | | - ) |
1088 | | - else: |
1089 | | - negative_add_time_ids = add_time_ids |
| 1088 | + else: |
| 1089 | + negative_add_time_ids = add_time_ids |
1090 | 1090 |
|
1091 | | - if self.do_classifier_free_guidance: |
1092 | | - prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
1093 | | - add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) |
1094 | | - add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) |
| 1091 | + if self.do_classifier_free_guidance: |
| 1092 | + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0) |
| 1093 | + add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0) |
| 1094 | + add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0) |
1095 | 1095 |
|
1096 | | - prompt_embeds = prompt_embeds.to(device) |
1097 | | - add_text_embeds = add_text_embeds.to(device) |
1098 | | - add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) |
| 1096 | + prompt_embeds = prompt_embeds.to(device) |
| 1097 | + add_text_embeds = add_text_embeds.to(device) |
| 1098 | + add_time_ids = add_time_ids.to(device).repeat(batch_size * num_images_per_prompt, 1) |
1099 | 1099 | addition_embed_type_row.append((prompt_embeds, add_text_embeds, add_time_ids)) |
1100 | 1100 | embeddings_and_added_time.append(addition_embed_type_row) |
1101 | 1101 |
|
|
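The diff above re-indents the SDXL `add_time_ids` / classifier-free-guidance block so that it runs once per tile inside the `(row, col)` loop, building the time ids from each tile's own crop offset before concatenating negative and positive embeddings. Below is a minimal, self-contained sketch of that per-tile pattern, not the pipeline itself: the helper `get_add_time_ids`, the function name `assemble_tile_embeddings`, the tensor shapes, and the default sizes are illustrative assumptions, and the `batch_size * num_images_per_prompt` repeat from the real code is omitted for brevity.

```python
# Sketch of the per-tile conditioning assembly shown in the diff above.
# Helper names, argument names, and shapes are assumptions for illustration.
import torch


def get_add_time_ids(original_size, crops_coords_top_left, target_size, dtype):
    # SDXL-style additional conditioning vector:
    # (orig_h, orig_w, crop_top, crop_left, target_h, target_w) -> shape (1, 6)
    return torch.tensor(
        [list(original_size) + list(crops_coords_top_left) + list(target_size)], dtype=dtype
    )


def assemble_tile_embeddings(
    prompt_embeds,                  # (1, seq_len, dim) positive prompt embeddings
    negative_prompt_embeds,         # (1, seq_len, dim) negative prompt embeddings
    pooled_prompt_embeds,           # (1, pooled_dim)
    negative_pooled_prompt_embeds,  # (1, pooled_dim)
    crops_coords_top_left,          # nested list indexed by [row][col]
    original_size=(1024, 1024),
    target_size=(1024, 1024),
    do_classifier_free_guidance=True,
    device="cpu",
):
    embeddings_and_added_time = []
    for row in range(len(crops_coords_top_left)):
        addition_embed_type_row = []
        for col in range(len(crops_coords_top_left[row])):
            add_text_embeds = pooled_prompt_embeds
            # Each tile gets time ids built from its own crop offset, so the
            # conditioning reflects where the tile sits inside the full canvas.
            add_time_ids = get_add_time_ids(
                original_size, crops_coords_top_left[row][col], target_size, prompt_embeds.dtype
            )
            negative_add_time_ids = add_time_ids
            if do_classifier_free_guidance:
                # Negative conditioning first, positive second, along the batch dim.
                tile_prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
                tile_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
                tile_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
            else:
                tile_prompt_embeds, tile_text_embeds, tile_time_ids = (
                    prompt_embeds, add_text_embeds, add_time_ids
                )
            addition_embed_type_row.append(
                (tile_prompt_embeds.to(device), tile_text_embeds.to(device), tile_time_ids.to(device))
            )
        embeddings_and_added_time.append(addition_embed_type_row)
    return embeddings_and_added_time


if __name__ == "__main__":
    # Dummy 2x2 tile grid with made-up crop offsets, just to exercise the shapes.
    pe = torch.randn(1, 77, 2048)
    ppe = torch.randn(1, 1280)
    crops = [[(0, 0), (0, 512)], [(512, 0), (512, 512)]]
    grid = assemble_tile_embeddings(pe, torch.zeros_like(pe), ppe, torch.zeros_like(ppe), crops)
    print(len(grid), len(grid[0]), grid[0][0][2].shape)  # 2 2 torch.Size([2, 6])
```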