Commit f0e21a9

Merge branch 'main' into dduf

2 parents: 627aec0 + 8eb73c8

File tree: 9 files changed, +85 / -9 lines

examples/community/pipeline_hunyuandit_differential_img2img.py

Lines changed: 2 additions & 0 deletions

@@ -1008,6 +1008,8 @@ def __call__(
             self.transformer.inner_dim // self.transformer.num_heads,
             grid_crops_coords,
             (grid_height, grid_width),
+            device=device,
+            output_type="pt",
         )

         style = torch.tensor([0], device=device)
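A minimal sketch of the new call path, assuming this repository's module layout and HunyuanDiT's head dimension of 88; the grid sizes are illustrative:

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed
from diffusers.pipelines.hunyuandit.pipeline_hunyuandit import get_resize_crop_region_for_grid

grid_height, grid_width, head_dim = 64, 64, 88
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), 64)
# output_type="pt" builds the embedding directly on `device`, avoiding the
# old numpy round trip; with the default use_real=True this is a (cos, sin) pair.
cos, sin = get_2d_rotary_pos_embed(
    head_dim, grid_crops_coords, (grid_height, grid_width), device=device, output_type="pt"
)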
File renamed without changes.

src/diffusers/models/attention.py

Lines changed: 10 additions & 3 deletions

@@ -188,8 +188,13 @@ def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
         self._chunk_dim = dim

     def forward(
-        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
     ):
+        joint_attention_kwargs = joint_attention_kwargs or {}
         if self.use_dual_attention:
             norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2 = self.norm1(
                 hidden_states, emb=temb
@@ -206,15 +211,17 @@ def forward(

         # Attention.
         attn_output, context_attn_output = self.attn(
-            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+            hidden_states=norm_hidden_states,
+            encoder_hidden_states=norm_encoder_hidden_states,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
         attn_output = gate_msa.unsqueeze(1) * attn_output
         hidden_states = hidden_states + attn_output

         if self.use_dual_attention:
-            attn_output2 = self.attn2(hidden_states=norm_hidden_states2)
+            attn_output2 = self.attn2(hidden_states=norm_hidden_states2, **joint_attention_kwargs)
             attn_output2 = gate_msa2.unsqueeze(1) * attn_output2
             hidden_states = hidden_states + attn_output2
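The block normalizes a possibly-None dict once, then splats it into every attention call. A standalone sketch with hypothetical stand-in names; in the real block the kwargs reach the attention processor (for example an extra scale argument) without changing the block's signature:

from typing import Any, Dict, Optional

def attn(hidden_states, **kwargs):
    # Stand-in for self.attn / self.attn2: forwarded kwargs land here.
    print("processor received:", kwargs)
    return hidden_states

def block_forward(hidden_states, joint_attention_kwargs: Optional[Dict[str, Any]] = None):
    joint_attention_kwargs = joint_attention_kwargs or {}  # None -> {}
    return attn(hidden_states=hidden_states, **joint_attention_kwargs)

block_forward([0.0])                  # processor received: {}
block_forward([0.0], {"scale": 0.5})  # processor received: {'scale': 0.5}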

src/diffusers/models/embeddings.py

Lines changed: 51 additions & 1 deletion

@@ -957,7 +957,57 @@ def get_3d_rotary_pos_embed_allegro(
     return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w


-def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
+def get_2d_rotary_pos_embed(
+    embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
+):
+    """
+    RoPE for image tokens with 2d structure.
+
+    Args:
+        embed_dim (`int`):
+            The embedding dimension size.
+        crops_coords (`Tuple[int]`):
+            The top-left and bottom-right coordinates of the crop.
+        grid_size (`Tuple[int]`):
+            The grid size of the positional embedding.
+        use_real (`bool`):
+            If True, return real part and imaginary part separately. Otherwise, return complex numbers.
+        device (`torch.device`, *optional*):
+            The device used to create tensors.
+
+    Returns:
+        `torch.Tensor`: positional embedding with shape `(grid_size * grid_size, embed_dim/2)`.
+    """
+    if output_type == "np":
+        deprecation_message = (
+            "`get_2d_rotary_pos_embed` uses `torch` and supports `device`."
+            " `from_numpy` is no longer required."
+            " Pass `output_type='pt'` to use the new version now."
+        )
+        deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
+        return _get_2d_rotary_pos_embed_np(
+            embed_dim=embed_dim,
+            crops_coords=crops_coords,
+            grid_size=grid_size,
+            use_real=use_real,
+        )
+    start, stop = crops_coords
+    # scaling the end by (steps - 1) / steps matches np.linspace(..., endpoint=False)
+    grid_h = torch.linspace(
+        start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
+    )
+    grid_w = torch.linspace(
+        start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
+    )
+    grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
+    grid = torch.stack(grid, dim=0)  # [2, W, H]
+
+    grid = grid.reshape([2, 1, *grid.shape[1:]])
+    pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
+    return pos_embed
+
+
+def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
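A quick check of the comment above for the start == 0 case, with hypothetical values: scaling the endpoint by (steps - 1) / steps makes torch.linspace reproduce the sample points of np.linspace(..., endpoint=False), which the deprecated numpy path used:

import numpy as np
import torch

start, stop, steps = 0.0, 32.0, 8
ref = np.linspace(start, stop, steps, endpoint=False)           # 0, 4, ..., 28
new = torch.linspace(start, stop * (steps - 1) / steps, steps)  # same samples
assert np.allclose(ref, new.numpy())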

src/diffusers/models/transformers/transformer_sd3.py

Lines changed: 5 additions & 1 deletion

@@ -411,11 +411,15 @@ def custom_forward(*inputs):
                     hidden_states,
                     encoder_hidden_states,
                     temb,
+                    joint_attention_kwargs,
                     **ckpt_kwargs,
                 )
             elif not is_skip:
                 encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    temb=temb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

             # controlnet residual
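torch.utils.checkpoint.checkpoint replays the wrapped function with the same positional arguments, which is why joint_attention_kwargs rides along as an extra positional input rather than a keyword. A minimal sketch with illustrative names, assuming the non-reentrant checkpoint variant:

import torch
import torch.utils.checkpoint

layer = torch.nn.Linear(4, 4)

def block(hidden_states, joint_attention_kwargs):
    # Non-tensor extras such as a kwargs dict pass through checkpointing.
    scale = (joint_attention_kwargs or {}).get("scale", 1.0)
    return layer(hidden_states) * scale

def create_custom_forward(module):
    def custom_forward(*inputs):
        return module(*inputs)
    return custom_forward

x = torch.randn(2, 4, requires_grad=True)
out = torch.utils.checkpoint.checkpoint(
    create_custom_forward(block), x, {"scale": 0.5}, use_reentrant=False
)
out.sum().backward()  # gradients flow through the recomputed block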

src/diffusers/models/unets/unet_2d.py

Lines changed: 2 additions & 1 deletion

@@ -97,6 +97,7 @@ def __init__(
         out_channels: int = 3,
         center_input_sample: bool = False,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +123,7 @@ def __init__(
         super().__init__()

         self.sample_size = sample_size
-        time_embed_dim = block_out_channels[0] * 4
+        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

         # Check inputs
         if len(down_block_types) != len(up_block_types):
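With this change the time-embedding width can be set explicitly instead of the implicit default of block_out_channels[0] * 4. A minimal sketch, assuming a small two-block UNet2DModel with illustrative sizes:

import torch
from diffusers import UNet2DModel

unet = UNet2DModel(
    sample_size=32,
    in_channels=3,
    out_channels=3,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "UpBlock2D"),
    time_embedding_dim=256,  # overrides the default 32 * 4 = 128
)
sample = torch.randn(1, 3, 32, 32)
out = unet(sample, timestep=10).sample  # (1, 3, 32, 32)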

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 5 additions & 1 deletion

@@ -925,7 +925,11 @@ def __call__(
         base_size = 512 // 8 // self.transformer.config.patch_size
         grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
         image_rotary_emb = get_2d_rotary_pos_embed(
-            self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+            self.transformer.inner_dim // self.transformer.num_heads,
+            grid_crops_coords,
+            (grid_height, grid_width),
+            device=device,
+            output_type="pt",
         )

         style = torch.tensor([0], device=device)

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 5 additions & 1 deletion

@@ -798,7 +798,11 @@ def __call__(
         base_size = 512 // 8 // self.transformer.config.patch_size
         grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
         image_rotary_emb = get_2d_rotary_pos_embed(
-            self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+            self.transformer.inner_dim // self.transformer.num_heads,
+            grid_crops_coords,
+            (grid_height, grid_width),
+            device=device,
+            output_type="pt",
         )

         style = torch.tensor([0], device=device)

src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py

Lines changed: 5 additions & 1 deletion

@@ -818,7 +818,11 @@ def __call__(
         base_size = 512 // 8 // self.transformer.config.patch_size
         grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
         image_rotary_emb = get_2d_rotary_pos_embed(
-            self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
+            self.transformer.inner_dim // self.transformer.num_heads,
+            grid_crops_coords,
+            (grid_height, grid_width),
+            device=device,
+            output_type="pt",
         )

         style = torch.tensor([0], device=device)
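All three HunyuanDiT pipelines now opt in to the torch path; calling get_2d_rotary_pos_embed without output_type="pt" still works but emits the deprecation warning scheduled for 0.33.0. A sketch with illustrative crop coordinates (start at 0, where the two paths agree numerically):

import torch
from diffusers.models.embeddings import get_2d_rotary_pos_embed

crops_coords = ((0, 0), (64, 64))
grid_size = (64, 64)

# Deprecated numpy-backed path: warns, removal planned for diffusers 0.33.0.
cos_old, sin_old = get_2d_rotary_pos_embed(88, crops_coords, grid_size)

# New path used above: tensors are created directly on the requested device.
cos_new, sin_new = get_2d_rotary_pos_embed(
    88, crops_coords, grid_size, device=torch.device("cpu"), output_type="pt"
)
assert torch.allclose(cos_old, cos_new, atol=1e-5)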
