Skip to content

Commit f2e7731

Browse files
committed
Add deprecation
1 parent 6561450 commit f2e7731

File tree

5 files changed

+48
-1
lines changed

5 files changed

+48
-1
lines changed

examples/community/pipeline_hunyuandit_differential_img2img.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1009,6 +1009,7 @@ def __call__(
10091009
grid_crops_coords,
10101010
(grid_height, grid_width),
10111011
device=device,
1012+
output_type="pt",
10121013
)
10131014

10141015
style = torch.tensor([0], device=device)

src/diffusers/models/embeddings.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -957,7 +957,9 @@ def get_3d_rotary_pos_embed_allegro(
957957
return freqs_t, freqs_h, freqs_w, grid_t, grid_h, grid_w
958958

959959

960-
def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None):
960+
def get_2d_rotary_pos_embed(
961+
embed_dim, crops_coords, grid_size, use_real=True, device: Optional[torch.device] = None, output_type: str = "np"
962+
):
961963
"""
962964
RoPE for image tokens with 2d structure.
963965
@@ -976,6 +978,19 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True, d
976978
Returns:
977979
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
978980
"""
981+
if output_type == "np":
982+
deprecation_message = (
983+
"`get_2d_sincos_pos_embed` uses `torch` and supports `device`."
984+
" `from_numpy` is no longer required."
985+
" Pass `output_type='pt' to use the new version now."
986+
)
987+
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
988+
return _get_2d_rotary_pos_embed_np(
989+
embed_dim=embed_dim,
990+
crops_coords=crops_coords,
991+
grid_size=grid_size,
992+
use_real=use_real,
993+
)
979994
start, stop = crops_coords
980995
# scale end by (steps−1)/steps matches np.linspace(..., endpoint=False)
981996
grid_h = torch.linspace(
@@ -992,6 +1007,34 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True, d
9921007
return pos_embed
9931008

9941009

1010+
def _get_2d_rotary_pos_embed_np(embed_dim, crops_coords, grid_size, use_real=True):
1011+
"""
1012+
RoPE for image tokens with 2d structure.
1013+
1014+
Args:
1015+
embed_dim: (`int`):
1016+
The embedding dimension size
1017+
crops_coords (`Tuple[int]`)
1018+
The top-left and bottom-right coordinates of the crop.
1019+
grid_size (`Tuple[int]`):
1020+
The grid size of the positional embedding.
1021+
use_real (`bool`):
1022+
If True, return real part and imaginary part separately. Otherwise, return complex numbers.
1023+
1024+
Returns:
1025+
`torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
1026+
"""
1027+
start, stop = crops_coords
1028+
grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
1029+
grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
1030+
grid = np.meshgrid(grid_w, grid_h) # here w goes first
1031+
grid = np.stack(grid, axis=0) # [2, W, H]
1032+
1033+
grid = grid.reshape([2, 1, *grid.shape[1:]])
1034+
pos_embed = get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=use_real)
1035+
return pos_embed
1036+
1037+
9951038
def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
9961039
"""
9971040
Get 2D RoPE from grid.

src/diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -929,6 +929,7 @@ def __call__(
929929
grid_crops_coords,
930930
(grid_height, grid_width),
931931
device=device,
932+
output_type="pt",
932933
)
933934

934935
style = torch.tensor([0], device=device)

src/diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -802,6 +802,7 @@ def __call__(
802802
grid_crops_coords,
803803
(grid_height, grid_width),
804804
device=device,
805+
output_type="pt",
805806
)
806807

807808
style = torch.tensor([0], device=device)

src/diffusers/pipelines/pag/pipeline_pag_hunyuandit.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,7 @@ def __call__(
822822
grid_crops_coords,
823823
(grid_height, grid_width),
824824
device=device,
825+
output_type="pt",
825826
)
826827

827828
style = torch.tensor([0], device=device)

0 commit comments

Comments
 (0)