
Commit 1b7fb36

Review updates
1 parent 6ac5cbb commit 1b7fb36

File tree

1 file changed: +73 −51 lines changed

src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py

Lines changed: 73 additions & 51 deletions
@@ -103,6 +103,31 @@ def __init__(
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
+    @staticmethod
+    def calculate_shift(
+        image_seq_len,
+        base_seq_len: int = 256,
+        max_seq_len: int = 4096,
+        base_shift: float = 0.5,
+        max_shift: float = 1.15,
+    ):
+        """Calculate shift parameter based on image dimensions.
+
+        Args:
+            image_seq_len: Length of the image sequence (height/vae_factor/2 * width/vae_factor/2)
+            base_seq_len: Base sequence length for interpolation
+            max_seq_len: Maximum sequence length for interpolation
+            base_shift: Base shift value
+            max_shift: Maximum shift value
+
+        Returns:
+            Calculated shift parameter (mu)
+        """
+        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+        b = base_shift - m * base_seq_len
+        mu = image_seq_len * m + b
+        return mu
+
     def check_inputs(
         self,
         prompt,
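The helper is a straight-line interpolation through (base_seq_len, base_shift) and (max_seq_len, max_shift), evaluated at image_seq_len, so larger images get a larger shift. A minimal standalone sketch of the same math (resolutions and printed values are illustrative):

    # Same interpolation as calculate_shift above, outside the class.
    def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.15):
        m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
        b = base_shift - m * base_seq_len
        return image_seq_len * m + b

    # With vae_scale_factor = 8, a 1024x1024 image gives
    # image_seq_len = (1024 // 8 // 2) ** 2 = 4096:
    print(calculate_shift(4096))  # 1.15 (max_shift)
    print(calculate_shift(256))   # 0.5 (base_shift)
    print(calculate_shift(1024))  # ~0.63, partway along the line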
@@ -305,41 +330,8 @@ def encode_prompt(
 
         return prompt_embeds, prompt_attention_mask, negative_prompt_embeds, negative_prompt_attention_mask
 
-    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_latents
-    def prepare_latents(
-        self,
-        batch_size,
-        num_channels_latents,
-        height,
-        width,
-        dtype,
-        device,
-        generator,
-        latents=None,
-    ):
-        if latents is not None:
-            return latents.to(device=device, dtype=dtype)
-
-        shape = (
-            batch_size,
-            num_channels_latents,
-            int(height) // self.vae_scale_factor,
-            int(width) // self.vae_scale_factor,
-        )
-
-        if isinstance(generator, list) and len(generator) != batch_size:
-            raise ValueError(
-                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-            )
-
-        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-
-        return latents
-
     def get_timesteps(self, num_inference_steps, strength, device):
         # Set timesteps using the full range initially
-        self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps.to(device=device)
 
         if len(timesteps) != num_inference_steps:
@@ -349,18 +341,29 @@ def get_timesteps(self, num_inference_steps, strength, device):
         init_timestep = min(num_inference_steps * strength, num_inference_steps)
 
         t_start = int(max(num_inference_steps - init_timestep, 0))
-        timesteps = self.scheduler.timesteps[t_start:]
+        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:]
+
+        # Set begin index if scheduler supports it
+        if hasattr(self.scheduler, "set_begin_index"):
+            self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
         return timesteps, num_inference_steps - t_start
 
-    def prepare_img2img_latents(
+    def prepare_latents(
         self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
     ):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
                 f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
             )
 
+        # Check for latents_mean and latents_std in the VAE config
+        latents_mean = latents_std = None
+        if hasattr(self.vae.config, "latents_mean") and self.vae.config.latents_mean is not None:
+            latents_mean = torch.tensor(self.vae.config.latents_mean).view(1, 4, 1, 1)
+        if hasattr(self.vae.config, "latents_std") and self.vae.config.latents_std is not None:
+            latents_std = torch.tensor(self.vae.config.latents_std).view(1, 4, 1, 1)
+
         image = image.to(device=device, dtype=dtype)
 
         batch_size = batch_size * num_images_per_prompt
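Concretely, strength decides how much of the freshly set schedule survives the slice. A worked example with illustrative numbers, assuming a first-order scheduler (scheduler.order == 1):

    num_inference_steps, strength = 50, 0.6
    init_timestep = min(num_inference_steps * strength, num_inference_steps)  # 30.0
    t_start = int(max(num_inference_steps - init_timestep, 0))                # 20
    # The pipeline keeps scheduler.timesteps[20:], i.e. the last 30 of 50
    # timesteps, so a higher strength retains more denoising steps and moves
    # the output further from the input image; set_begin_index(20) tells the
    # scheduler where stepping will resume.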
@@ -404,26 +407,30 @@ def prepare_img2img_latents(
         if isinstance(generator, list):
             sample = torch.cat(
                 [
-                    torch.randn(
+                    randn_tensor(
                         (1, *mean.shape[1:]),
                         generator=generator[i],
-                        device=generator[i].device if hasattr(generator[i], "device") else "cpu",
+                        device=mean.device,
                         dtype=mean.dtype,
-                    ).to(mean.device)
+                    )
                     for i in range(batch_size)
                 ]
             )
         else:
             # Single generator - use its device if it has one
-            generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu"
-            noise = torch.randn(mean.shape, generator=generator, device=generator_device, dtype=mean.dtype)
-            sample = noise.to(mean.device)
+            sample = randn_tensor(mean.shape, generator=generator, device=mean.device, dtype=mean.dtype)
 
         # Compute latents
         latents = mean + std * sample
 
-        # Scale latents
-        latents = latents * self.vae.config.scaling_factor
+        # Apply standardization if VAE has mean and std defined in config
+        if latents_mean is not None and latents_std is not None:
+            latents_mean = latents_mean.to(device=device, dtype=dtype)
+            latents_std = latents_std.to(device=device, dtype=dtype)
+            latents = (latents - latents_mean) * self.vae.config.scaling_factor / latents_std
+        else:
+            # Scale latents
+            latents = latents * self.vae.config.scaling_factor
 
         # get the original timestep using init_timestep
         init_timestep = timestep  # Use the passed timestep directly
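The standardization branch follows the convention some diffusers VAE checkpoints use for per-channel latent statistics (latents_mean / latents_std in the VAE config); whatever is applied at encode time must be inverted before decoding. A round-trip sketch with made-up statistics:

    import torch

    scaling_factor = 0.13025                     # illustrative value
    latents_mean = torch.randn(1, 4, 1, 1)       # stands in for vae.config.latents_mean
    latents_std = torch.rand(1, 4, 1, 1) + 0.5   # stands in for vae.config.latents_std

    raw = torch.randn(2, 4, 64, 64)
    # Encode-time standardization, as in the diff:
    scaled = (raw - latents_mean) * scaling_factor / latents_std
    # Decode-time inverse (what the decoding path would need to apply):
    restored = scaled * latents_std / scaling_factor + latents_mean
    assert torch.allclose(raw, restored, atol=1e-5)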
@@ -433,21 +440,20 @@ def prepare_img2img_latents(
         if isinstance(generator, list):
             noise = torch.cat(
                 [
-                    torch.randn(
+                    randn_tensor(
                         (1, *latents.shape[1:]),
                         generator=generator[i],
-                        device=generator[i].device if hasattr(generator[i], "device") else "cpu",
+                        device=latents.device,
                         dtype=latents.dtype,
-                    ).to(latents.device)
+                    )
                     for i in range(batch_size)
                 ]
             )
         else:
             # Single generator - use its device if it has one
-            generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu"
-            noise = torch.randn(
-                latents.shape, generator=generator, device=generator_device, dtype=latents.dtype
-            ).to(latents.device)
+            noise = randn_tensor(
+                latents.shape, generator=generator, device=latents.device, dtype=latents.dtype
+            )
 
         latents = self.scheduler.scale_noise(latents, init_timestep, noise)
 
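Both hunks replace torch.randn(...).to(...) with diffusers' randn_tensor helper, which draws the noise on the generator's device and only then moves it to the requested device, so a CPU-seeded generator reproduces identical values whether the latents live on CPU or CUDA. A short sketch (shape and dtype are illustrative):

    import torch
    from diffusers.utils.torch_utils import randn_tensor

    gen = torch.Generator("cpu").manual_seed(0)
    # Noise is sampled on the CPU generator, then placed on the target device,
    # so the same seed yields the same latents across devices.
    noise = randn_tensor((1, 4, 64, 64), generator=gen, device=torch.device("cpu"), dtype=torch.float32)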
@@ -654,13 +660,29 @@ def __call__(
             prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
         # 5. Prepare timesteps
+        # Calculate shift parameter based on image dimensions
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
+
+        # Calculate mu (shift parameter) based on image dimensions
+        mu = self.calculate_shift(
+            image_seq_len,
+            self.scheduler.config.get("base_image_seq_len", 256),
+            self.scheduler.config.get("max_image_seq_len", 4096),
+            self.scheduler.config.get("base_shift", 0.5),
+            self.scheduler.config.get("max_shift", 1.15),
+        )
+
+        # Set timesteps with shift parameter
+        self.scheduler.set_timesteps(num_inference_steps, device=device, mu=mu)
+
+        # Now adjust for strength
         timesteps, num_inference_steps = self.get_timesteps(
             num_inference_steps, strength, device
         )
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)  # Get the first timestep(s) for initial noise
 
         # 6. Prepare latent variables
-        latents = self.prepare_img2img_latents(
+        latents = self.prepare_latents(
            image,
            latent_timestep,
            batch_size,
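The ordering here matters: set_timesteps(..., mu=mu) must build the full shifted schedule before get_timesteps trims it for strength. A sketch of that order with FlowMatchEulerDiscreteScheduler (the flow-matching scheduler AuraFlow pipelines use; constructing it standalone with use_dynamic_shifting=True is an assumption for illustration):

    from diffusers import FlowMatchEulerDiscreteScheduler

    sched = FlowMatchEulerDiscreteScheduler(use_dynamic_shifting=True)
    sched.set_timesteps(50, mu=0.8)        # full 50-step schedule, shifted by mu
    t_start = 20                           # e.g. strength = 0.6 with 50 steps
    timesteps = sched.timesteps[t_start:]  # img2img keeps only the tail
    print(len(timesteps))                  # 30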
