162 changes: 99 additions & 63 deletions src/diffusers/models/transformers/transformer_hidream_image.py
@@ -8,7 +8,7 @@
from ...loaders import PeftAdapterMixin
from ...models.modeling_outputs import Transformer2DModelOutput
from ...models.modeling_utils import ModelMixin
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
from ...utils.torch_utils import maybe_allow_in_graph
from ..attention import Attention
from ..embeddings import TimestepEmbedding, Timesteps
@@ -686,46 +686,106 @@ def unpatchify(self, x: torch.Tensor, img_sizes: List[Tuple[int, int]], is_train
x = torch.cat(x_arr, dim=0)
return x

def patchify(self, x, max_seq, img_sizes=None):
pz2 = self.config.patch_size * self.config.patch_size
if isinstance(x, torch.Tensor):
B, C = x.shape[0], x.shape[1]
device = x.device
dtype = x.dtype
else:
B, C = len(x), x[0].shape[0]
device = x[0].device
dtype = x[0].dtype
x_masks = torch.zeros((B, max_seq), dtype=dtype, device=device)
def patchify(self, hidden_states):
batch_size, channels, height, width = hidden_states.shape
patch_size = self.config.patch_size
patch_height, patch_width = height // patch_size, width // patch_size
device = hidden_states.device
dtype = hidden_states.dtype

if img_sizes is not None:
for i, img_size in enumerate(img_sizes):
x_masks[i, 0 : img_size[0] * img_size[1]] = 1
B, C, S, _ = x.shape
x = x.permute(0, 2, 3, 1).reshape(B, S, pz2 * C)
elif isinstance(x, torch.Tensor):
B, C, Hp1, Wp2 = x.shape
pH, pW = Hp1 // self.config.patch_size, Wp2 // self.config.patch_size
x = x.reshape(B, C, pH, self.config.patch_size, pW, self.config.patch_size)
x = x.permute(0, 2, 4, 3, 5, 1)
x = x.reshape(B, pH * pW, self.config.patch_size * self.config.patch_size * C)
img_sizes = [[pH, pW]] * B
x_masks = None
hidden_states_masks = torch.zeros((batch_size, self.max_seq), dtype=dtype, device=device)

if hidden_states.shape[-2] != hidden_states.shape[-1]:
# Handle non-square latents
out = torch.zeros(
(batch_size, channels, self.max_seq, patch_size * patch_size),
dtype=dtype,
device=device,
)
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height, patch_size, patch_width, patch_size
)
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height * patch_width, patch_size * patch_size
)
out[:, :, 0 : patch_height * patch_width] = hidden_states
hidden_states = out

img_sizes = torch.tensor([patch_height, patch_width], dtype=torch.int64, device=device).reshape(-1)
img_ids = torch.zeros(patch_height, patch_width, 3, device=device)

row_indices = torch.arange(patch_height, device=device)[:, None]
col_indices = torch.arange(patch_width, device=device)[None, :]

img_ids[..., 1] = img_ids[..., 1] + row_indices
img_ids[..., 2] = img_ids[..., 2] + col_indices

img_ids = img_ids.reshape(patch_height * patch_width, -1)
img_ids_pad = torch.zeros(self.max_seq, 3, device=device)
img_ids_pad[: patch_height * patch_width, :] = img_ids

img_sizes = img_sizes.unsqueeze(0).repeat(batch_size, 1)
img_ids = img_ids_pad.unsqueeze(0).repeat(batch_size, 1, 1)

hidden_states_masks[:, : patch_height * patch_width] = 1.0

hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(
batch_size, self.max_seq, patch_size * patch_size * channels
)
else:
raise NotImplementedError
return x, x_masks, img_sizes
# Handle square latents
hidden_states = hidden_states.reshape(
batch_size, channels, patch_height, patch_size, patch_width, patch_size
)
hidden_states = hidden_states.permute(0, 2, 4, 3, 5, 1)
hidden_states = hidden_states.reshape(
batch_size, patch_height * patch_width, patch_size * patch_size * channels
)
img_sizes = [[patch_height, patch_width]] * batch_size
hidden_states_masks = None

img_ids = torch.zeros(patch_height, patch_width, 3, device=device)

row_indices = torch.arange(patch_height, device=device)[:, None]
col_indices = torch.arange(patch_width, device=device)[None, :]
img_ids[..., 1] = img_ids[..., 1] + row_indices
img_ids[..., 2] = img_ids[..., 2] + col_indices

img_ids = (
img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
.unsqueeze(0)
.repeat(batch_size, 1, 1)
)

return hidden_states, hidden_states_masks, img_sizes, img_ids

def forward(
self,
hidden_states: torch.Tensor,
timesteps: torch.LongTensor = None,
encoder_hidden_states: torch.Tensor = None,
encoder_hidden_states_t5: torch.Tensor = None,
encoder_hidden_states_llama3: torch.Tensor = None,
pooled_embeds: torch.Tensor = None,
img_sizes: Optional[List[Tuple[int, int]]] = None,
img_ids: Optional[torch.Tensor] = None,
attention_kwargs: Optional[Dict[str, Any]] = None,
return_dict: bool = True,
**kwargs,
):
encoder_hidden_states = kwargs.get("encoder_hidden_states", None)
img_ids = kwargs.get("img_ids", None)
img_sizes = kwargs.get("img_sizes", None)
if encoder_hidden_states is not None:
deprecation_message = "The `encoder_hidden_states` argument is deprecated. Please use `encoder_hidden_states_t5` and `encoder_hidden_states_llama3` instead."
deprecate("encoder_hidden_states", "0.34.0", deprecation_message)
encoder_hidden_states_t5 = encoder_hidden_states[0]
encoder_hidden_states_llama3 = encoder_hidden_states[1]
if img_ids is not None:
deprecation_message = "The `img_ids` argument is deprecated and will be ignored."
deprecate("img_ids", "0.34.0", deprecation_message)
if img_sizes is not None:
deprecation_message = "The `img_sizes` argument is deprecated and will be ignored."
deprecate("img_sizes", "0.34.0", deprecation_message)

if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -745,42 +805,18 @@ def forward(
batch_size = hidden_states.shape[0]
hidden_states_type = hidden_states.dtype

if hidden_states.shape[-2] != hidden_states.shape[-1]:
B, C, H, W = hidden_states.shape
patch_size = self.config.patch_size
pH, pW = H // patch_size, W // patch_size
out = torch.zeros(
(B, C, self.max_seq, patch_size * patch_size),
dtype=hidden_states.dtype,
device=hidden_states.device,
)
hidden_states = hidden_states.reshape(B, C, pH, patch_size, pW, patch_size)
hidden_states = hidden_states.permute(0, 1, 2, 4, 3, 5)
hidden_states = hidden_states.reshape(B, C, pH * pW, patch_size * patch_size)
out[:, :, 0 : pH * pW] = hidden_states
hidden_states = out
# Patchify the input
hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)
@YehLi Apr 16, 2025

    # Patchify the input
    if img_sizes is not None and img_ids is not None:
        B, C, S, _ = hidden_states.shape
        hidden_states_masks = torch.zeros((B, self.max_seq), dtype=hidden_states.dtype, device=hidden_states.device)
        for i, img_size in enumerate(img_sizes):
            hidden_states_masks[i, 0:img_size[0] * img_size[1]] = 1
        hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(B, S, self.config.patch_size * self.config.patch_size * C)
    else:
        hidden_states, hidden_states_masks, img_sizes, img_ids = self.patchify(hidden_states)

Keep img_sizes and img_ids since samples with different aspect ratios would be contained in the same batch during training

Collaborator Author

thanks! I made it so that the patchify step can be skipped, i.e. for samples with different aspect ratios, the user would have to prepare the patchified hidden_states, hidden_states_masks, img_sizes and img_ids outside of the model and pass them as inputs. I think it is easier this way:

  1. this library is mainly for inference and fine-tuning; I think training with different aspect ratios is mostly done during pre-training, no?
  2. they would need to prepare latents differently, and we currently do not support that anyway


  1. During pre-training, the input images are not cropped by default. Thus, it is recommended that users do not apply any cropping to the images and instead prepare the patchified hidden_states, img_sizes and img_ids outside of the model during fine-tuning. It is fine to finetune the model with multi-aspect bucketing like SDXL. The main purpose of img_sizes and img_ids is to give users more options if they do not want to crop the images.
  2. They can prepare latents in the dataloader with shape (B, C, H*W / patch_size / patch_size, patch_size*patch_size) and then pad them to the same length (B, C, S, patch_size*patch_size); see the sketch below.
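
A minimal sketch of what such a dataloader-side step could look like. The `prepatchify` helper, its signature, and the concrete sizes in the example are illustrative assumptions (not diffusers API); it only demonstrates the (C, H*W / patch_size**2, patch_size*patch_size) layout, the padding to a shared sequence length, and the per-sample img_sizes/img_ids described above:

```python
import torch


def prepatchify(latent: torch.Tensor, patch_size: int, max_seq: int):
    # Hypothetical dataloader-side helper: turn one latent of shape (C, H, W) into
    # (C, max_seq, patch_size*patch_size) plus its img_size and img_ids, so latents
    # with different aspect ratios can be padded and batched together.
    channels, height, width = latent.shape
    ph, pw = height // patch_size, width // patch_size

    # (C, ph, p, pw, p) -> (C, ph, pw, p, p) -> (C, ph*pw, p*p)
    patches = latent.reshape(channels, ph, patch_size, pw, patch_size)
    patches = patches.permute(0, 1, 3, 2, 4).reshape(channels, ph * pw, patch_size * patch_size)

    # Pad the token dimension to the shared sequence length.
    padded = patches.new_zeros((channels, max_seq, patch_size * patch_size))
    padded[:, : ph * pw] = patches

    # Per-sample patch-grid size and (unused, row, col) position ids, also padded to max_seq.
    img_size = torch.tensor([ph, pw], dtype=torch.int64)
    img_ids = torch.zeros(ph, pw, 3)
    img_ids[..., 1] = img_ids[..., 1] + torch.arange(ph)[:, None]
    img_ids[..., 2] = img_ids[..., 2] + torch.arange(pw)[None, :]
    img_ids_pad = torch.zeros(max_seq, 3)
    img_ids_pad[: ph * pw] = img_ids.reshape(ph * pw, 3)

    return padded, img_size, img_ids_pad


# Example: two latents with swapped aspect ratios (16 latent channels assumed),
# padded to the same length and stacked into one batch of shape (B, C, S, p*p).
a, size_a, ids_a = prepatchify(torch.randn(16, 128, 96), patch_size=2, max_seq=4096)
b, size_b, ids_b = prepatchify(torch.randn(16, 96, 128), patch_size=2, max_seq=4096)
hidden_states = torch.stack([a, b])        # (2, 16, 4096, 4)
img_sizes = torch.stack([size_a, size_b])  # (2, 2)
img_ids = torch.stack([ids_a, ids_b])      # (2, 4096, 3)
```

The batched tensors can then be fed to the model together with a sequence mask built from img_sizes, as in the suggestion above.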

Collaborator Author

cc @a-r-r-o-w @sayakpaul @linoytsaban here in case we want to try this out for training script

Collaborator Author

thanks @YehLi !


# Embed the hidden states
hidden_states = self.x_embedder(hidden_states)

# 0. time
timesteps = self.t_embedder(timesteps, hidden_states_type)
p_embedder = self.p_embedder(pooled_embeds)
temb = timesteps + p_embedder

hidden_states, hidden_states_masks, img_sizes = self.patchify(hidden_states, self.max_seq, img_sizes)
if hidden_states_masks is None:
pH, pW = img_sizes[0]
img_ids = torch.zeros(pH, pW, 3, device=hidden_states.device)
img_ids[..., 1] = img_ids[..., 1] + torch.arange(pH, device=hidden_states.device)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.arange(pW, device=hidden_states.device)[None, :]
img_ids = (
img_ids.reshape(img_ids.shape[0] * img_ids.shape[1], img_ids.shape[2])
.unsqueeze(0)
.repeat(batch_size, 1, 1)
)
hidden_states = self.x_embedder(hidden_states)

T5_encoder_hidden_states = encoder_hidden_states[0]
encoder_hidden_states = encoder_hidden_states[-1]
encoder_hidden_states = [encoder_hidden_states[k] for k in self.config.llama_layers]
encoder_hidden_states = [encoder_hidden_states_llama3[k] for k in self.config.llama_layers]

if self.caption_projection is not None:
new_encoder_hidden_states = []
Expand All @@ -789,9 +825,9 @@ def forward(
enc_hidden_state = enc_hidden_state.view(batch_size, -1, hidden_states.shape[-1])
new_encoder_hidden_states.append(enc_hidden_state)
encoder_hidden_states = new_encoder_hidden_states
T5_encoder_hidden_states = self.caption_projection[-1](T5_encoder_hidden_states)
T5_encoder_hidden_states = T5_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(T5_encoder_hidden_states)
encoder_hidden_states_t5 = self.caption_projection[-1](encoder_hidden_states_t5)
encoder_hidden_states_t5 = encoder_hidden_states_t5.view(batch_size, -1, hidden_states.shape[-1])
encoder_hidden_states.append(encoder_hidden_states_t5)

txt_ids = torch.zeros(
batch_size,