Commit 8c7c483

combine attention processor
1 parent 6f29e2a commit 8c7c483

2 files changed (+33, -66 lines)


src/diffusers/models/attention_processor.py

Lines changed: 31 additions & 66 deletions
@@ -822,19 +822,27 @@ def __init__(
         self.to_out = nn.Linear(inner_dim * (1 + len(kernel_sizes)), out_channels, bias=False)
         self.norm_out = get_normalization(norm_type, num_features=out_channels)

-        self.processor = SanaMultiscaleLinearAttnProcessor2_0()
-        self.processor_quadratic = SanaMultiscaleQuadraticAttnProcessor2_0()
+        self.processor = SanaMultiscaleAttnProcessor2_0()

-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        height, width = hidden_states.shape[-2:]
+    def apply_linear_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)  # Adds padding
+        scores = torch.matmul(value, key.transpose(-1, -2))
+        hidden_states = torch.matmul(scores, query)

-        if height * width > self.attention_head_dim:
-            hidden_states = self.processor(self, hidden_states)
-        else:
-            hidden_states = self.processor_quadratic(self, hidden_states)
+        hidden_states = hidden_states.to(dtype=torch.float32)
+        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + self.eps)
+        return hidden_states

+    def apply_quadratic_attention(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor) -> torch.Tensor:
+        scores = torch.matmul(key.transpose(-1, -2), query)
+        scores = scores.to(dtype=torch.float32)
+        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
+        hidden_states = torch.matmul(value, scores)
         return hidden_states

+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return self.processor(self, hidden_states)
+

 class AttnProcessor:
     r"""
@@ -5089,65 +5097,18 @@ def __call__(
         return hidden_states


-class SanaMultiscaleLinearAttnProcessor2_0:
+class SanaMultiscaleAttnProcessor2_0:
     r"""
-    Processor for implementing multiscale linear attention.
+    Processor for implementing multiscale quadratic attention.
     """

     def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
-        residual = hidden_states
-
-        batch_size, _, height, width = hidden_states.shape
-        original_dtype = hidden_states.dtype
-
-        hidden_states = hidden_states.movedim(1, -1)
-        query = attn.to_q(hidden_states)
-        key = attn.to_k(hidden_states)
-        value = attn.to_v(hidden_states)
-        hidden_states = torch.cat([query, key, value], dim=3)
-        hidden_states = hidden_states.movedim(-1, 1)
-
-        multiscale_hidden_states = [hidden_states]
-        for block in attn.to_qkv_multiscale:
-            multiscale_hidden_states.append(block(hidden_states))
-
-        hidden_states = torch.cat(multiscale_hidden_states, dim=1)
-
-        hidden_states = hidden_states.to(dtype=torch.float32)
-        hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)
-
-        query, key, value = hidden_states.chunk(3, dim=2)
-        query = attn.nonlinearity(query)
-        key = attn.nonlinearity(key)
-        value = F.pad(value, (0, 0, 0, 1), mode="constant", value=1)
-
-        scores = torch.matmul(value, key.transpose(-1, -2))
-        hidden_states = torch.matmul(scores, query)
-
-        hidden_states = hidden_states.to(dtype=torch.float32)
-        hidden_states = hidden_states[:, :, :-1] / (hidden_states[:, :, -1:] + attn.eps)
-        hidden_states = hidden_states.to(dtype=original_dtype)
-
-        hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
-        hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
-
-        if attn.norm_type == "rms_norm":
-            hidden_states = attn.norm_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
+        height, width = hidden_states.shape[-2:]
+        if height * width > attn.attention_head_dim:
+            use_linear_attention = True
         else:
-            hidden_states = attn.norm_out(hidden_states)
+            use_linear_attention = False

-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        return hidden_states
-
-
-class SanaMultiscaleQuadraticAttnProcessor2_0:
-    r"""
-    Processor for implementing multiscale quadratic attention.
-    """
-
-    def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:
         residual = hidden_states

         batch_size, _, height, width = list(hidden_states.size())
@@ -5166,17 +5127,21 @@ def __call__(self, attn: SanaMultiscaleLinearAttention, hidden_states: torch.Tensor) -> torch.Tensor:

         hidden_states = torch.cat(multi_scale_qkv, dim=1)

+        if use_linear_attention:
+            # for linear attention upcast hidden_states to float32
+            hidden_states = hidden_states.to(dtype=torch.float32)
+
         hidden_states = hidden_states.reshape(batch_size, -1, 3 * attn.attention_head_dim, height * width)

         query, key, value = hidden_states.chunk(3, dim=2)
         query = attn.nonlinearity(query)
         key = attn.nonlinearity(key)

-        scores = torch.matmul(key.transpose(-1, -2), query)
-        scores = scores.to(dtype=torch.float32)
-        scores = scores / (torch.sum(scores, dim=2, keepdim=True) + attn.eps)
-        scores = scores.to(dtype=original_dtype)
-        hidden_states = torch.matmul(value, scores)
+        if use_linear_attention:
+            hidden_states = attn.apply_linear_attention(query, key, value)
+            hidden_states = hidden_states.to(dtype=original_dtype)
+        else:
+            hidden_states = attn.apply_quadratic_attention(query, key, value)

         hidden_states = torch.reshape(hidden_states, (batch_size, -1, height, width))
         hidden_states = attn.to_out(hidden_states.movedim(1, -1)).movedim(-1, 1)
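
With the two processors merged, SanaMultiscaleAttnProcessor2_0 picks the branch per call: linear attention when the flattened spatial sequence is longer than the head dimension, quadratic attention otherwise, which is when each is the cheaper matmul path. A rough sketch of that dispatch rule; the helper name is hypothetical, only the condition comes from the diff.

def pick_attention(height: int, width: int, attention_head_dim: int) -> str:
    seq_len = height * width
    # linear attention costs O(seq_len * head_dim^2) in matmuls,
    # quadratic attention costs O(seq_len^2 * head_dim),
    # so linear is cheaper exactly when seq_len > head_dim
    return "linear" if seq_len > attention_head_dim else "quadratic"

print(pick_attention(32, 32, attention_head_dim=8))  # linear (1024 tokens > 8)
print(pick_attention(2, 2, attention_head_dim=8))    # quadratic (4 tokens <= 8)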

src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 2 additions & 0 deletions
@@ -53,6 +53,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = hidden_states * self.nonlinearity(gate)

         hidden_states = self.conv_point(hidden_states)
+        # move channel to the last dimension so we apply RMSnorm across channel dimension
         hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)

         return hidden_states + residual
@@ -82,6 +83,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_states = self.conv2(hidden_states)

         if self.norm_type == "rms_norm":
+            # move channel to the last dimension so we apply RMSnorm across channel dimension
             hidden_states = self.norm(hidden_states.movedim(1, -1)).movedim(-1, 1)
         else:
             hidden_states = self.norm(hidden_states)
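
The two added comments describe the same pattern: norm layers such as RMSNorm normalize over the trailing dimension, so the channel axis of an NCHW feature map is moved last, normalized, and moved back. A minimal sketch of that pattern, using torch.nn.RMSNorm (recent PyTorch) as a stand-in for the autoencoder's own norm module; the sizes are arbitrary.

import torch
import torch.nn as nn

norm = nn.RMSNorm(64)                      # normalizes over the trailing (channel) dimension
x = torch.randn(1, 64, 32, 32)             # NCHW feature map
x = norm(x.movedim(1, -1)).movedim(-1, 1)  # NCHW -> NHWC, normalize channels, back to NCHW
print(x.shape)                             # torch.Size([1, 64, 32, 32])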
