Skip to content
32 changes: 20 additions & 12 deletions src/diffusers/models/transformers/auraflow_transformer_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,15 +74,23 @@ def pe_selection_index_based_on_dim(self, h, w):
# PE will be viewed as 2d-grid, and H/p x W/p of the PE will be selected
# because original input are in flattened format, we have to flatten this 2d grid as well.
h_p, w_p = h // self.patch_size, w // self.patch_size
original_pe_indexes = torch.arange(self.pos_embed.shape[1])
h_max, w_max = int(self.pos_embed_max_size**0.5), int(self.pos_embed_max_size**0.5)
original_pe_indexes = original_pe_indexes.view(h_max, w_max)

# Calculate the top-left corner indices for the centered patch grid
starth = h_max // 2 - h_p // 2
endh = starth + h_p
startw = w_max // 2 - w_p // 2
endw = startw + w_p
original_pe_indexes = original_pe_indexes[starth:endh, startw:endw]
return original_pe_indexes.flatten()

# Generate the row and column indices for the desired patch grid
rows = torch.arange(starth, starth + h_p, device=self.pos_embed.device)
cols = torch.arange(startw, startw + w_p, device=self.pos_embed.device)

# Create a 2D grid of indices
row_indices, col_indices = torch.meshgrid(rows, cols, indexing="ij")

# Convert the 2D grid indices to flattened 1D indices
selected_indices = (row_indices * w_max + col_indices).flatten()

return selected_indices

def forward(self, latent):
batch_size, num_channels, height, width = latent.size()
Expand Down Expand Up @@ -275,17 +283,17 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, From
sample_size (`int`): The width of the latent images. This is fixed during training since
it is used to learn a number of position embeddings.
patch_size (`int`): Patch size to turn the input data into small patches.
in_channels (`int`, *optional*, defaults to 16): The number of channels in the input.
in_channels (`int`, *optional*, defaults to 4): The number of channels in the input.
num_mmdit_layers (`int`, *optional*, defaults to 4): The number of layers of MMDiT Transformer blocks to use.
num_single_dit_layers (`int`, *optional*, defaults to 4):
num_single_dit_layers (`int`, *optional*, defaults to 32):
The number of layers of Transformer blocks to use. These blocks use concatenated image and text
representations.
attention_head_dim (`int`, *optional*, defaults to 64): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 18): The number of heads to use for multi-head attention.
attention_head_dim (`int`, *optional*, defaults to 256): The number of channels in each head.
num_attention_heads (`int`, *optional*, defaults to 12): The number of heads to use for multi-head attention.
joint_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
caption_projection_dim (`int`): Number of dimensions to use when projecting the `encoder_hidden_states`.
out_channels (`int`, defaults to 16): Number of output channels.
pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
out_channels (`int`, defaults to 4): Number of output channels.
pos_embed_max_size (`int`, defaults to 1024): Maximum positions to embed from the image latents.
"""

_no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
Expand Down