Commit 3ab6db7

pag tests and refactor
1 parent 1ab6ab2 commit 3ab6db7

File tree: 3 files changed, +35 / -142 lines

src/diffusers/models/attention_processor.py
src/diffusers/pipelines/pag/pipeline_pag_sana.py
src/diffusers/pipelines/sana/pipeline_sana.py

src/diffusers/models/attention_processor.py

Lines changed: 23 additions & 119 deletions
@@ -5410,95 +5410,45 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        original_dtype = hidden_states.dtype

-        # chunk
         hidden_states_uncond, hidden_states_org, hidden_states_ptb = hidden_states.chunk(3)
         hidden_states_org = torch.cat([hidden_states_uncond, hidden_states_org])

-        # original path
-        batch_size, sequence_length, _ = (
-            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
         query = attn.to_q(hidden_states_org)
         key = attn.to_k(hidden_states_org)
         value = attn.to_v(hidden_states_org)

-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        dtype = query.dtype
-
-        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
-        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
-        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))

-        query = self.kernel_func(query)  # B, h, h_d, N
+        query = self.kernel_func(query)
         key = self.kernel_func(key)

-        # need torch.float
         query, key, value = query.float(), key.float(), value.float()

         value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
-        vk = torch.matmul(value, key)
-        hidden_states_org = torch.matmul(vk, query)
+        scores = torch.matmul(value, key)
+        hidden_states_org = torch.matmul(scores, query)

-        if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
-            hidden_states_org = hidden_states_org.float()
         hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
+        hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+        hidden_states_org = hidden_states_org.to(original_dtype)

-        hidden_states_org = hidden_states_org.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
-        hidden_states_org = hidden_states_org.to(dtype)
-
-        # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
-        # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)

-        if input_ndim == 4:
-            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
         # perturbed path (identity attention)
-        batch_size, sequence_length, _ = hidden_states_ptb.shape
+        hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)

-        value = attn.to_v(hidden_states_ptb)
-        hidden_states_ptb = value
-        hidden_states_ptb = hidden_states_ptb.to(dtype)
-
-        # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
-        # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

-        if input_ndim == 4:
-            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        # cat
         hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        if hidden_states.dtype == torch.float16:
+        if original_dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)

         return hidden_states
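Note: the reshape refactor above is purely cosmetic. Below is a minimal standalone sketch (not part of this commit; the shapes are assumed to match the (batch, seq_len, heads * head_dim) projections used by the processor) showing that the old reshape-based head split and the new unflatten-based one produce identical tensors:

import torch

batch, seq_len, heads, head_dim = 2, 16, 4, 8
x = torch.randn(batch, seq_len, heads * head_dim)

# old path: move channels first, then reshape into (batch, heads, head_dim, seq_len)
old = x.transpose(-1, -2).reshape(batch, heads, head_dim, -1)
# new path: same layout via unflatten, which only ever splits one dimension
new = x.transpose(1, 2).unflatten(1, (heads, -1))

assert torch.equal(old, new)  # both are (batch, heads, head_dim, seq_len)

Because unflatten does not need the flattened size spelled out, the batch_size/inner_dim/head_dim bookkeeping variables in the old code could be dropped.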
@@ -5520,93 +5470,47 @@ def __call__(
         hidden_states: torch.Tensor,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
-        temb: Optional[torch.Tensor] = None,
-        *args,
-        **kwargs,
     ) -> torch.Tensor:
-        if len(args) > 0 or kwargs.get("scale", None) is not None:
-            deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
-            deprecate("scale", "1.0.0", deprecation_message)
-
-        residual = hidden_states
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        original_dtype = hidden_states.dtype

-        # chunk
         hidden_states_org, hidden_states_ptb = hidden_states.chunk(2)

-        # original path
-        batch_size, sequence_length, _ = (
-            hidden_states_org.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
         query = attn.to_q(hidden_states_org)
         key = attn.to_k(hidden_states_org)
         value = attn.to_v(hidden_states_org)

-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        dtype = query.dtype
-
-        query = query.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
-        key = key.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1).transpose(-1, -2)
-        value = value.transpose(-1, -2).reshape(batch_size, attn.heads, head_dim, -1)
+        query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
+        key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
+        value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))

-        query = self.kernel_func(query)  # B, h, h_d, N
+        query = self.kernel_func(query)
         key = self.kernel_func(key)

-        # need torch.float
         query, key, value = query.float(), key.float(), value.float()

         value = F.pad(value, (0, 0, 0, 1), mode="constant", value=self.pad_val)
-        vk = torch.matmul(value, key)
-        hidden_states_org = torch.matmul(vk, query)
+        scores = torch.matmul(value, key)
+        hidden_states_org = torch.matmul(scores, query)

         if hidden_states_org.dtype in [torch.float16, torch.bfloat16]:
             hidden_states_org = hidden_states_org.float()
-        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)

-        hidden_states_org = hidden_states_org.view(batch_size, attn.heads * head_dim, -1).permute(0, 2, 1)
-        hidden_states_org = hidden_states_org.to(dtype)
+        hidden_states_org = hidden_states_org[:, :, :-1] / (hidden_states_org[:, :, -1:] + self.eps)
+        hidden_states_org = hidden_states_org.flatten(1, 2).transpose(1, 2)
+        hidden_states_org = hidden_states_org.to(original_dtype)

-        # linear proj
         hidden_states_org = attn.to_out[0](hidden_states_org)
-        # dropout
         hidden_states_org = attn.to_out[1](hidden_states_org)

-        if input_ndim == 4:
-            hidden_states_org = hidden_states_org.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
         # perturbed path (identity attention)
-        batch_size, sequence_length, _ = hidden_states_ptb.shape
+        hidden_states_ptb = attn.to_v(hidden_states_ptb).to(original_dtype)

-        hidden_states_ptb = attn.to_v(hidden_states_ptb)
-        hidden_states_ptb = hidden_states_ptb.to(dtype)
-
-        # linear proj
         hidden_states_ptb = attn.to_out[0](hidden_states_ptb)
-        # dropout
         hidden_states_ptb = attn.to_out[1](hidden_states_ptb)

-        if input_ndim == 4:
-            hidden_states_ptb = hidden_states_ptb.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        # cat
         hidden_states = torch.cat([hidden_states_org, hidden_states_ptb])

-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        if hidden_states.dtype == torch.float16:
+        if original_dtype == torch.float16:
             hidden_states = hidden_states.clip(-65504, 65504)

         return hidden_states
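Note: both processors above implement the same linear-attention trick, so a toy sketch may help when reading the hunks. This is an illustration under assumptions, not library code: it assumes a ReLU feature map for kernel_func, pad_val = 1.0, and a small eps:

import torch
import torch.nn.functional as F

B, heads, head_dim, N = 1, 2, 8, 16
eps = 1e-15

# stand-ins for kernel_func(query)/kernel_func(key); shapes mirror the processor:
# query (B, heads, head_dim, N), key (B, heads, N, head_dim), value (B, heads, head_dim, N)
query = F.relu(torch.randn(B, heads, head_dim, N))
key = F.relu(torch.randn(B, heads, N, head_dim))
value = torch.randn(B, heads, head_dim, N)

# pad value with a row of ones so one pair of matmuls yields numerator and normalizer together
value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1.0)  # (B, heads, head_dim + 1, N)
scores = torch.matmul(value, key)                               # (B, heads, head_dim + 1, head_dim)
out = torch.matmul(scores, query)                               # (B, heads, head_dim + 1, N)
out = out[:, :, :-1] / (out[:, :, -1:] + eps)                   # divide by the ones-row accumulator
print(out.shape)  # torch.Size([1, 2, 8, 16])

The main difference between the two hunks is the batch layout: the first chunks the batch into (uncond, org, ptb) for classifier-free guidance plus PAG, the second into (org, ptb) for PAG alone; in both, the perturbed branch skips attention entirely and only applies to_v and to_out (identity attention).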

src/diffusers/pipelines/pag/pipeline_pag_sana.py

Lines changed: 8 additions & 15 deletions
@@ -142,7 +142,7 @@ def __init__(
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: FlowDPMSolverMultistepScheduler,
-        pag_applied_layers: Union[str, List[str]] = "transformer_blocks.8",
+        pag_applied_layers: Union[str, List[str]] = "transformer_blocks.0",
     ):
         super().__init__()

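Note on the changed default: "transformer_blocks.0" points PAG at the first transformer block instead of the ninth. As a toy illustration only (this is not diffusers' actual PAG layer-selection code, just a sketch of how a dotted identifier like the new default can pick out modules by name):

import re
import torch.nn as nn

class TinyTransformer(nn.Module):
    def __init__(self, num_blocks: int = 3):
        super().__init__()
        self.transformer_blocks = nn.ModuleList([nn.Identity() for _ in range(num_blocks)])

model = TinyTransformer()
pattern = "transformer_blocks.0"  # same style of identifier as the new default
selected = [name for name, _ in model.named_modules() if re.search(pattern, name)]
print(selected)  # ['transformer_blocks.0']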
@@ -511,8 +511,11 @@ def _clean_caption(self, caption):

         return caption.strip()

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
+    # Copied from diffusers.pipelines.sana.pipeline_sana.SanaPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        if latents is not None:
+            return latents.to(device=device, dtype=dtype)
+
         shape = (
             batch_size,
             num_channels_latents,
@@ -525,13 +528,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device)
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         return latents

     @property
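Note: the new prepare_latents early-returns user-supplied latents and drops the init_noise_sigma scaling (presumably a no-op for the flow-matching scheduler used here). A minimal standalone sketch of the resulting control flow (hypothetical helper; torch.randn stands in for diffusers' randn_tensor utility):

from typing import Optional
import torch

def prepare_latents_sketch(
    shape: tuple,
    dtype: torch.dtype,
    device: torch.device,
    generator: Optional[torch.Generator] = None,
    latents: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    # user-supplied latents are only moved/cast, never rescaled
    if latents is not None:
        return latents.to(device=device, dtype=dtype)
    # otherwise draw fresh noise
    return torch.randn(shape, generator=generator, dtype=dtype).to(device)

noise = prepare_latents_sketch((1, 32, 32, 32), torch.float32, torch.device("cpu"))
print(noise.shape)  # torch.Size([1, 32, 32, 32])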
@@ -561,8 +558,8 @@ def __call__(
         sigmas: List[float] = None,
         guidance_scale: float = 4.5,
         num_images_per_prompt: Optional[int] = 1,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: int = 1024,
+        width: int = 1024,
         eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -771,9 +768,6 @@ def __call__(
         # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

-        # 6.1 Prepare micro-conditions.
-        added_cond_kwargs = {"resolution": None, "aspect_ratio": None}
-
         # 7. Denoising loop
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
         self._num_timesteps = len(timesteps)
@@ -796,7 +790,6 @@ def __call__(
                     encoder_hidden_states=prompt_embeds,
                     encoder_attention_mask=prompt_attention_mask,
                     timestep=timestep,
-                    added_cond_kwargs=added_cond_kwargs,
                     return_dict=False,
                 )[0]
                 noise_pred = noise_pred.float()

src/diffusers/pipelines/sana/pipeline_sana.py

Lines changed: 4 additions & 8 deletions
@@ -504,8 +504,10 @@ def _clean_caption(self, caption):

         return caption.strip()

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
+        if latents is not None:
+            return latents.to(device=device, dtype=dtype)
+
         shape = (
             batch_size,
             num_channels_latents,
@@ -518,13 +520,7 @@ def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )

-        if latents is None:
-            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
-        else:
-            latents = latents.to(device)
-
-        # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
+        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         return latents

     @property
