Commit 1f4adb9

fix use of variable latents to tile_latents
1 parent 871f333 commit 1f4adb9

File tree

1 file changed: +10 -10 lines

examples/community/mixture_tiling_sdxl.py

Lines changed: 10 additions & 10 deletions
@@ -1080,7 +1080,7 @@ def __call__(
                         tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]
                         # expand the latents if we are doing classifier free guidance
                         latent_model_input = (
-                            torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else latents
+                            torch.cat([tile_latents] * 2) if self.do_classifier_free_guidance else tile_latents
                         )
                         latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
 
@@ -1089,15 +1089,15 @@ def __call__(
                             "text_embeds": embeddings_and_added_time[row][col][1],
                             "time_ids": embeddings_and_added_time[row][col][2],
                         }
-                        with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
-                            noise_pred = self.unet(
-                                latent_model_input,
-                                t,
-                                encoder_hidden_states=embeddings_and_added_time[row][col][0],
-                                cross_attention_kwargs=self.cross_attention_kwargs,
-                                added_cond_kwargs=added_cond_kwargs,
-                                return_dict=False,
-                            )[0]
+                        #with torch.amp.autocast(device.type, dtype=dtype, enabled=dtype != self.unet.dtype):
+                        noise_pred = self.unet(
+                            latent_model_input,
+                            t,
+                            encoder_hidden_states=embeddings_and_added_time[row][col][0],
+                            cross_attention_kwargs=self.cross_attention_kwargs,
+                            added_cond_kwargs=added_cond_kwargs,
+                            return_dict=False,
+                        )[0]
 
                         # perform guidance
                         if self.do_classifier_free_guidance:
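
For illustration, here is a minimal, self-contained sketch (not the pipeline's actual code; the tensor shapes and tile bounds are assumed values) of why the else-branch of the classifier-free-guidance expansion must use the tile slice tile_latents rather than the full grid latents: with guidance disabled, the old code fed the entire latent grid to the UNet for a single tile, producing a spatial mismatch.

# Minimal sketch (assumed shapes and tile bounds, not the pipeline's code):
# the else-branch must stay on the tile slice, otherwise a full-grid tensor
# reaches the UNet whenever classifier-free guidance is disabled.
import torch

do_classifier_free_guidance = False  # the bug only manifests in this branch

latents = torch.randn(1, 4, 128, 128)  # full SDXL latent grid (assumed size)
px_row_init, px_row_end, px_col_init, px_col_end = 0, 64, 0, 64  # one tile's window (assumed)

tile_latents = latents[:, :, px_row_init:px_row_end, px_col_init:px_col_end]

# Before this commit: the else-branch returned the whole grid.
buggy_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else latents
# After this commit: both branches operate on the tile.
fixed_input = torch.cat([tile_latents] * 2) if do_classifier_free_guidance else tile_latents

print(buggy_input.shape)  # torch.Size([1, 4, 128, 128]) -- does not match the tile
print(fixed_input.shape)  # torch.Size([1, 4, 64, 64])   -- matches tile_latents

With guidance enabled the two variants coincide, since the concatenation branch already used tile_latents; the fix only changes the non-guided path.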
