Skip to content

Commit 10d798a

Browse files
committed
fix
1 parent 6e05c21 commit 10d798a

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

src/diffusers/pipelines/ovis_image/pipeline_ovis_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
"""
5454

5555

56+
# Copied from diffusers.pipelines.flux.pipeline_flux_utils.calculate_shift
5657
def calculate_shift(
5758
image_seq_len,
5859
base_seq_len: int = 256,
@@ -66,7 +67,7 @@ def calculate_shift(
6667
return mu
6768

6869

69-
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
70+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_utils.retrieve_timesteps
7071
def retrieve_timesteps(
7172
scheduler,
7273
num_inference_steps: Optional[int] = None,

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,24 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
309309
latents = latents * self.scheduler.init_noise_sigma
310310
return latents
311311

312+
def _get_add_time_ids(
313+
self, original_size, crops_coords_top_left, target_size, dtype, text_encoder_projection_dim=None
314+
):
315+
add_time_ids = list(original_size + crops_coords_top_left + target_size)
316+
317+
passed_add_embed_dim = (
318+
self.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
319+
)
320+
expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
321+
322+
if expected_add_embed_dim != passed_add_embed_dim:
323+
raise ValueError(
324+
f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
325+
)
326+
327+
add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
328+
return add_time_ids
329+
312330
@torch.no_grad()
313331
@replace_example_docstring(EXAMPLE_DOC_STRING)
314332
def __call__(

0 commit comments

Comments
 (0)