
Commit 6ff1af8

Update i2i tests, fix style
1 parent 6fb6f2b commit 6ff1af8

File tree: 7 files changed (+196, -55 lines)


src/diffusers/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -345,8 +345,8 @@
             "AudioLDM2ProjectionModel",
             "AudioLDM2UNet2DConditionModel",
             "AudioLDMPipeline",
-            "AuraFlowPipeline",
             "AuraFlowImg2ImgPipeline",
+            "AuraFlowPipeline",
             "BlipDiffusionControlNetPipeline",
             "BlipDiffusionPipeline",
             "CLIPImageProjection",

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -515,7 +515,7 @@
            AudioLDM2ProjectionModel,
            AudioLDM2UNet2DConditionModel,
        )
-        from .aura_flow import AuraFlowPipeline, AuraFlowImg2ImgPipeline
+        from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline
        from .blip_diffusion import BlipDiffusionPipeline
        from .cogvideo import (
            CogVideoXFunControlPipeline,

src/diffusers/pipelines/aura_flow/pipeline_aura_flow_img2img.py

Lines changed: 85 additions & 31 deletions
@@ -21,10 +21,10 @@
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AuraFlowTransformer2DModel, AutoencoderKL
 from diffusers.models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 from diffusers.schedulers import FlowMatchEulerDiscreteScheduler
 from diffusers.utils import is_torch_xla_available, logging, replace_example_docstring
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
 if is_torch_xla_available():
@@ -119,12 +119,12 @@ def check_inputs(
     ):
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should be in [0.0, 1.0] but is {strength}")
-
+
         patch_size = 2  # AuraFlow uses patch size of 2
         required_divisor = self.vae_scale_factor * patch_size
         if height % required_divisor != 0 or width % required_divisor != 0:
             raise ValueError(
-                f"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the transformer patch size ({patch_size}), which is {required_divisor}. "
+                rf"\`height\` and \`width\` have to be divisible by the VAE scale factor ({self.vae_scale_factor}) times the transformer patch size ({patch_size}), which is {required_divisor}. "
                 f"Your dimensions are ({height}, {width})."
             )
 
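
A note on the check above: pixel dimensions must survive two rounds of downsampling, the VAE's spatial compression and the transformer's 2x2 patchification, so both `height` and `width` must be multiples of `vae_scale_factor * patch_size`. A quick illustration (the value `vae_scale_factor = 8` is an assumption for a typical AuraFlow VAE, not something read from this diff):

```python
# Illustrative only: why the pipeline rejects dimensions that are not multiples of 16.
vae_scale_factor = 8  # assumed spatial compression of the VAE
patch_size = 2        # AuraFlow transformer patch size, per the comment in the hunk above
required_divisor = vae_scale_factor * patch_size  # 16

for height, width in [(1024, 768), (1000, 768)]:
    ok = height % required_divisor == 0 and width % required_divisor == 0
    print(f"{height}x{width}: {'accepted' if ok else 'rejected'}")  # 1024x768 accepted, 1000x768 rejected
```
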
@@ -339,7 +339,7 @@ def prepare_latents(
 
     def get_timesteps(self, num_inference_steps, strength, device):
         # 1. Call set_timesteps with num_inference_steps
-        self.scheduler.set_timesteps(num_inference_steps, device=device)  # Ensure scheduler uses the correct number of steps
+        self.scheduler.set_timesteps(num_inference_steps, device=device)
 
         # 2. Calculate strength-based number of steps and offset
         init_timestep_count = min(int(num_inference_steps * strength), num_inference_steps)
@@ -353,14 +353,7 @@ def get_timesteps(self, num_inference_steps, strength, device):
         return timesteps, num_actual_inference_steps
 
     def prepare_img2img_latents(
-        self,
-        image,
-        timestep,
-        batch_size,
-        num_images_per_prompt,
-        dtype,
-        device,
-        generator=None
+        self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None
     ):
         if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
             raise ValueError(
@@ -380,34 +373,87 @@ def prepare_img2img_latents(
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        if image.shape[0] == 1:
-            image = image.repeat(batch_size, 1, 1, 1)
+        # Handle different batch size scenarios
+        if image.shape[0] < batch_size:
+            if batch_size % image.shape[0] == 0:
+                # Duplicate the image to match the batch size
+                additional_image_per_prompt = batch_size // image.shape[0]
+                image = torch.cat([image] * additional_image_per_prompt, dim=0)
+            else:
+                raise ValueError(
+                    f"Cannot duplicate `image` of batch size {image.shape[0]} to {batch_size} text prompts."
+                    f" Batch size must be divisible by the image batch size."
+                )
 
         # encode the init image into latents and scale the latents
-        latents = self.vae.encode(image).latent_dist.sample(generator=generator)
+        # 1. Get VAE distribution parameters (on device)
+        latent_dist = self.vae.encode(image).latent_dist
+        mean, std = latent_dist.mean, latent_dist.std  # Already on device
+
+        # 2. Sample noise for each batch element individually if using multiple generators
+        if isinstance(generator, list):
+            sample = torch.cat(
+                [
+                    torch.randn(
+                        (1, *mean.shape[1:]),
+                        generator=generator[i],
+                        device=generator[i].device if hasattr(generator[i], "device") else "cpu",
+                        dtype=mean.dtype,
+                    ).to(mean.device)
+                    for i in range(batch_size)
+                ]
+            )
+        else:
+            # Single generator - use its device if it has one
+            generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu"
+            noise = torch.randn(mean.shape, generator=generator, device=generator_device, dtype=mean.dtype)
+            sample = noise.to(mean.device)
+
+        # Compute latents
+        latents = mean + std * sample
+
+        # Scale latents
         latents = latents * self.vae.config.scaling_factor
 
         # get the original timestep using init_timestep
         init_timestep = timestep
 
         # add noise to latents using the timesteps
-        noise = torch.randn(latents.shape, generator=generator, device=device, dtype=dtype)
-
+        # Handle noise generation with multiple generators if provided
+        if isinstance(generator, list):
+            noise = torch.cat(
+                [
+                    torch.randn(
+                        (1, *latents.shape[1:]),
+                        generator=generator[i],
+                        device=generator[i].device if hasattr(generator[i], "device") else "cpu",
+                        dtype=latents.dtype,
+                    ).to(latents.device)
+                    for i in range(batch_size)
+                ]
+            )
+        else:
+            # Single generator - use its device if it has one
+            generator_device = getattr(generator, "device", "cpu") if generator is not None else "cpu"
+            noise = torch.randn(
+                latents.shape, generator=generator, device=generator_device, dtype=latents.dtype
+            ).to(latents.device)
+
         # Ensure timestep tensor is on the same device
         t = init_timestep.to(latents.device)
-
+
         # Normalize timestep to [0, 1] range (using scheduler's config)
         t = t / self.scheduler.config.num_train_timesteps
-
+
         # Reshape t to match the dimensions needed for broadcasting
         required_dims = len(latents.shape)
         current_dims = len(t.shape)
         for _ in range(required_dims - current_dims):
             t = t.unsqueeze(-1)
-
+
         # Interpolation: x_t = t * x_1 + (1 - t) * x_0
         latents = t * noise + (1 - t) * latents
-
+
         return latents
 
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.upcast_vae
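
For orientation, the hunk above replaces plain posterior sampling with explicit mean/std sampling (so multi-generator batches stay reproducible) and then noises the encoded image with the flow-matching interpolation x_t = t * noise + (1 - t) * x_0, where t is the first retained timestep scaled into [0, 1]. A minimal, self-contained sketch of that noising step, with made-up tensor shapes and timestep values rather than the pipeline's real ones:

```python
import torch

# Flow-matching noising as used above: t=1 is pure noise, t=0 is the clean image latent.
num_train_timesteps = 1000                    # assumed scheduler config value
init_timestep = torch.tensor([600.0])         # hypothetical first retained timestep (strength-dependent)
t = (init_timestep / num_train_timesteps).view(-1, 1, 1, 1)

x0 = torch.randn(1, 4, 32, 32)                # stand-in for the scaled VAE latents of the input image
noise = torch.randn_like(x0)
x_t = t * noise + (1 - t) * x0                # partially noised latents handed to the denoising loop
print(x_t.shape)                              # torch.Size([1, 4, 32, 32])
```
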
@@ -606,13 +652,14 @@ def __call__(
            negative_prompt_attention_mask=negative_prompt_attention_mask,
            max_sequence_length=max_sequence_length,
        )
+
        if do_classifier_free_guidance:
            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
 
        # 5. Prepare timesteps
        timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
        latent_timestep = timesteps[:1]
-
+
        # 6. Prepare latent variables
        latents = self.prepare_img2img_latents(
            image,
@@ -632,10 +679,13 @@ def __call__(
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
 
-                # aura use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
-                # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
-                timestep = torch.tensor([t / 1000]).expand(latent_model_input.shape[0])
-                timestep = timestep.to(latents.device, dtype=latents.dtype)
+                # AureFlow use timestep value between 0 and 1, with t=1 as noise and t=0 as the image
+                # create a timestep tensor with the correct batch size
+                # ensure it matches the batch size of the model input
+                t_float = t / 1000
+                timestep_tensor = torch.full(
+                    (latent_model_input.shape[0],), t_float, device=latents.device, dtype=latents.dtype
+                )
 
                # Make sure latent_model_input has the same dtype as the transformer
                transformer_dtype = self.transformer.dtype
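
The replacement builds the per-sample timestep with `torch.full`, so its length always tracks the (possibly CFG-doubled) batch dimension of `latent_model_input` instead of relying on `expand`. A standalone illustration with hypothetical shapes:

```python
import torch

do_classifier_free_guidance = True
batch_size = 2
channels, height, width = 4, 32, 32  # hypothetical latent shape

latent_model_input = torch.randn(
    batch_size * 2 if do_classifier_free_guidance else batch_size, channels, height, width
)

t = 750  # current scheduler timestep in [0, 1000]
timestep_tensor = torch.full((latent_model_input.shape[0],), t / 1000, dtype=torch.float32)
print(timestep_tensor)  # four copies of 0.75, one per latent in the CFG-doubled batch
```
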
@@ -646,7 +696,7 @@ def __call__(
                noise_pred = self.transformer(
                    latent_model_input,
                    encoder_hidden_states=prompt_embeds,
-                    timestep=timestep,
+                    timestep=timestep_tensor,
                    return_dict=False,
                )[0]
 
@@ -682,15 +732,19 @@ def __call__(
        if needs_upcasting:
            self.upcast_vae()
            latents = latents.to(next(iter(self.vae.post_quant_conv.parameters())).dtype)
-
+
        # Apply proper scaling factor and shift factor if available
-        if hasattr(self.vae.config, "scaling_factor") and hasattr(self.vae.config, "shift_factor") and getattr(self.vae.config, "shift_factor", None) is not None:
+        if (
+            hasattr(self.vae.config, "scaling_factor")
+            and hasattr(self.vae.config, "shift_factor")
+            and getattr(self.vae.config, "shift_factor", None) is not None
+        ):
            # Handle both scaling and shifting
            latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
        else:
            # Just scale using standard approach
            latents = latents / self.vae.config.scaling_factor
-
+
        image = self.vae.decode(latents, return_dict=False)[0]
        image = self.image_processor.postprocess(image, output_type=output_type)
 
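
The branching above mirrors the usual diffusers convention for VAEs that define a `shift_factor`: encoding stores z = (x - shift_factor) * scaling_factor, so decoding must divide by the scale and then add the shift back. A hedged restatement as a standalone helper (illustrative, not the pipeline's code):

```python
# Sketch of the latent un-scaling applied before vae.decode, assuming the common
# diffusers convention z = (x - shift_factor) * scaling_factor at encode time.
def unscale_latents(latents, vae_config):
    latents = latents / vae_config.scaling_factor
    shift = getattr(vae_config, "shift_factor", None)
    if shift is not None:
        latents = latents + shift  # only VAEs that define a shift need this step
    return latents
```
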
@@ -700,4 +754,4 @@ def __call__(
        if not return_dict:
            return (image,)
 
-        return ImagePipelineOutput(images=image)
+        return ImagePipelineOutput(images=image)

src/diffusers/pipelines/auto_pipeline.py

Lines changed: 2 additions & 2 deletions
@@ -20,7 +20,7 @@
 from ..configuration_utils import ConfigMixin
 from ..models.controlnets import ControlNetUnionModel
 from ..utils import is_sentencepiece_available
-from .aura_flow import AuraFlowPipeline, AuraFlowImg2ImgPipeline
+from .aura_flow import AuraFlowImg2ImgPipeline, AuraFlowPipeline
 from .cogview3 import CogView3PlusPipeline
 from .cogview4 import CogView4ControlPipeline, CogView4Pipeline
 from .controlnet import (

@@ -165,7 +165,7 @@
        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
        ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
        ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
-        ("auraflow", AuraFlowImg2ImgPipeline),
+        ("auraflow", AuraFlowImg2ImgPipeline),
        ("lcm", LatentConsistencyModelImg2ImgPipeline),
        ("flux", FluxImg2ImgPipeline),
        ("flux-controlnet", FluxControlNetImg2ImgPipeline),

src/diffusers/utils/dummy_torch_and_transformers_objects.py

Lines changed: 15 additions & 0 deletions
@@ -257,6 +257,21 @@ def from_pretrained(cls, *args, **kwargs):
        requires_backends(cls, ["torch", "transformers"])
 
 
+class AuraFlowImg2ImgPipeline(metaclass=DummyObject):
+    _backends = ["torch", "transformers"]
+
+    def __init__(self, *args, **kwargs):
+        requires_backends(self, ["torch", "transformers"])
+
+    @classmethod
+    def from_config(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+    @classmethod
+    def from_pretrained(cls, *args, **kwargs):
+        requires_backends(cls, ["torch", "transformers"])
+
+
 class AuraFlowPipeline(metaclass=DummyObject):
    _backends = ["torch", "transformers"]
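
The dummy class keeps `from diffusers import AuraFlowImg2ImgPipeline` importable in environments without `torch`/`transformers`; actually using it then fails with a clear backend error. A small sketch of that behavior (only meaningful in an environment where those backends are missing):

```python
# In an environment missing torch or transformers, diffusers exposes this dummy class
# instead of the real pipeline, so the import itself never breaks.
from diffusers import AuraFlowImg2ImgPipeline

try:
    AuraFlowImg2ImgPipeline()  # __init__ calls requires_backends(...), which raises
except ImportError as err:
    print(err)                 # message explains that torch and transformers are required
```
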

tests/pipelines/aura_flow/test_pipeline_aura_flow.py

Lines changed: 3 additions & 0 deletions
@@ -135,3 +135,6 @@ def test_fused_qkv_projections(self):
    @unittest.skip("xformers attention processor does not exist for AuraFlow")
    def test_xformers_attention_forwardGenerator_pass(self):
        pass
+
+    def test_inference_batch_single_identical(self, batch_size=3, expected_max_diff=0.0004):
+        self._test_inference_batch_single_identical(batch_size=batch_size, expected_max_diff=expected_max_diff)
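
`test_inference_batch_single_identical` comes from the shared pipeline test mixin; roughly, it runs the pipeline once with a batched input and once without batching, then asserts the outputs agree within `expected_max_diff`, which this override loosens to 4e-4 for AuraFlow. A simplified sketch of the comparison it performs (not the mixin's actual code):

```python
import numpy as np

# Simplified illustration of the batched-vs-single consistency check.
def assert_batch_single_identical(batched_outputs, single_outputs, expected_max_diff=4e-4):
    # batched_outputs: images from one call with batch_size inputs
    # single_outputs: images from batch_size separate single-input calls
    for batched, single in zip(batched_outputs, single_outputs):
        max_diff = np.abs(np.asarray(batched) - np.asarray(single)).max()
        assert max_diff < expected_max_diff, f"batched vs. single mismatch: {max_diff}"
```
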
