Skip to content

Commit af5ecd9

Browse files
committed
Use torch in get_2d_rotary_pos_embed
1 parent 6131a93 commit af5ecd9

File tree

5 files changed

+25
-8
lines changed

5 files changed

+25
-8
lines changed

examples/community/pipeline_hunyuandit_differential_img2img.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1008,6 +1008,7 @@ def __call__(
10081008
self.transformer.inner_dim // self.transformer.num_heads,
10091009
grid_crops_coords,
10101010
(grid_height, grid_width),
1011+
device=device,
10111012
)
10121013

10131014
style = torch.tensor([0], device=device)

src/diffusers/models/embeddings.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,7 @@ def get_3d_rotary_pos_embed_allegro(
715715
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
716716

717717

718-
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
718+
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None):
719719
"""
720720
RoPE for image tokens with 2d structure.
721721
@@ -728,15 +728,22 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
728728
The grid size of the positional embedding.
729729
use_real (`bool`):
730730
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
731+
device: (`torch.device`, **optional**):
732+
The device used to create tensors.
731733
732734
Returns:
733735
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
734736
"""
735737
start, stop = crops_coords
736-
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
737-
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
738-
grid = np.meshgrid(grid_w, grid_h) # here w goes first
739-
grid = np.stack(grid, axis=0) # [2, W, H]
738+
# scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
739+
grid_h = torch.linspace(
740+
start[0], stop[0] * (grid_size[0] - 1) / grid_size[0], grid_size[0], device=device, dtype=torch.float32
741+
)
742+
grid_w = torch.linspace(
743+
start[1], stop[1] * (grid_size[1] - 1) / grid_size[1], grid_size[1], device=device, dtype=torch.float32
744+
)
745+
grid = torch.meshgrid(grid_w, grid_h, indexing="xy")
746+
grid = torch.stack(grid, dim=0) # [2, W, H]
740747

741748
grid = grid.reshape([2, 1, *grid.shape[1:]])
742749
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -925,7 +925,10 @@ def __call__(
925925
base_size = 512 // 8 // self.transformer.config.patch_size
926926
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
927927
image_rotary_emb = get_2d_rotary_pos_embed(
928-
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
928+
self.transformer.inner_dim // self.transformer.num_heads,
929+
grid_crops_coords,
930+
(grid_height, grid_width),
931+
device=device,
929932
)
930933

931934
style = torch.tensor([0], device=device)

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -798,7 +798,10 @@ def __call__(
798798
base_size = 512 // 8 // self.transformer.config.patch_size
799799
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
800800
image_rotary_emb = get_2d_rotary_pos_embed(
801-
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
801+
self.transformer.inner_dim // self.transformer.num_heads,
802+
grid_crops_coords,
803+
(grid_height, grid_width),
804+
device=device,
802805
)
803806

804807
style = torch.tensor([0], device=device)

src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -818,7 +818,10 @@ def __call__(
818818
base_size = 512 // 8 // self.transformer.config.patch_size
819819
grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
820820
image_rotary_emb = get_2d_rotary_pos_embed(
821-
self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
821+
self.transformer.inner_dim // self.transformer.num_heads,
822+
grid_crops_coords,
823+
(grid_height, grid_width),
824+
device=device,
822825
)
823826

824827
style = torch.tensor([0], device=device)

0 commit comments

Comments
 (0)