Skip to content

Commit 4c4b323

Browse files
authored
Use torch in get_3d_rotary_pos_embed/_allegro (huggingface#10161)
Use torch in get_3d_rotary_pos_embed/_allegro
1 parent 22d3a82 commit 4c4b323

File tree

8 files changed: 41 additions and 32 deletions.

examples/cogvideo/train_cogvideox_image_to_video_lora.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -872,10 +872,9 @@ def prepare_rotary_positional_embeddings(
872872
crops_coords=grid_crops_coords,
873873
grid_size=(grid_height, grid_width),
874874
temporal_size=num_frames,
875+
device=device,
875876
)
876877

877-
freqs_cos = freqs_cos.to(device=device)
878-
freqs_sin = freqs_sin.to(device=device)
879878
return freqs_cos, freqs_sin
880879

881880

examples/cogvideo/train_cogvideox_lora.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -894,10 +894,9 @@ def prepare_rotary_positional_embeddings(
894894
crops_coords=grid_crops_coords,
895895
grid_size=(grid_height, grid_width),
896896
temporal_size=num_frames,
897+
device=device,
897898
)
898899

899-
freqs_cos = freqs_cos.to(device=device)
900-
freqs_sin = freqs_sin.to(device=device)
901900
return freqs_cos, freqs_sin
902901

903902

src/diffusers/models/embeddings.py

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ def get_3d_rotary_pos_embed(
594594
use_real: bool = True,
595595
grid_type: str = "linspace",
596596
max_size: Optional[Tuple[int, int]] = None,
597+
device: Optional[torch.device] = None,
597598
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
598599
"""
599600
RoPE for video tokens with 3D structure.
@@ -621,16 +622,22 @@ def get_3d_rotary_pos_embed(
621622
if grid_type == "linspace":
622623
start, stop = crops_coords
623624
grid_size_h, grid_size_w = grid_size
624-
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
625-
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
626-
grid_t = np.arange(temporal_size, dtype=np.float32)
627-
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
625+
grid_h = torch.linspace(
626+
start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
627+
)
628+
grid_w = torch.linspace(
629+
start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
630+
)
631+
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
632+
grid_t = torch.linspace(
633+
0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
634+
)
628635
elif grid_type == "slice":
629636
max_h, max_w = max_size
630637
grid_size_h, grid_size_w = grid_size
631-
grid_h = np.arange(max_h, dtype=np.float32)
632-
grid_w = np.arange(max_w, dtype=np.float32)
633-
grid_t = np.arange(temporal_size, dtype=np.float32)
638+
grid_h = torch.arange(max_h, device=device, dtype=torch.float32)
639+
grid_w = torch.arange(max_w, device=device, dtype=torch.float32)
640+
grid_t = torch.arange(temporal_size, device=device, dtype=torch.float32)
634641
else:
635642
raise ValueError("Invalid value passed for `grid_type`.")
636643

@@ -640,10 +647,10 @@ def get_3d_rotary_pos_embed(
640647
dim_w = embed_dim // 8 * 3
641648

642649
# Temporal frequencies
643-
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, use_real=True)
650+
freqs_t = get_1d_rotary_pos_embed(dim_t, grid_t, theta=theta, use_real=True)
644651
# Spatial frequencies for height and width
645-
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, use_real=True)
646-
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, use_real=True)
652+
freqs_h = get_1d_rotary_pos_embed(dim_h, grid_h, theta=theta, use_real=True)
653+
freqs_w = get_1d_rotary_pos_embed(dim_w, grid_w, theta=theta, use_real=True)
647654

648655
# Broadcast and concatenate temporal and spatial frequencies (height and width) into a 3D tensor
649656
def combine_time_height_width(freqs_t, freqs_h, freqs_w):
@@ -686,14 +693,21 @@ def get_3d_rotary_pos_embed_allegro(
686693
temporal_size,
687694
interpolation_scale: Tuple[float, float, float] = (1.0, 1.0, 1.0),
688695
theta: int = 10000,
696+
device: Optional[torch.device] = None,
689697
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
690698
# TODO(aryan): docs
691699
start, stop = crops_coords
692700
grid_size_h, grid_size_w = grid_size
693701
interpolation_scale_t, interpolation_scale_h, interpolation_scale_w = interpolation_scale
694-
grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
695-
grid_h = np.linspace(start[0], stop[0], grid_size_h, endpoint=False, dtype=np.float32)
696-
grid_w = np.linspace(start[1], stop[1], grid_size_w, endpoint=False, dtype=np.float32)
702+
grid_t = torch.linspace(
703+
0, temporal_size * (temporal_size - 1) / temporal_size, temporal_size, device=device, dtype=torch.float32
704+
)
705+
grid_h = torch.linspace(
706+
start[0], stop[0] * (grid_size_h - 1) / grid_size_h, grid_size_h, device=device, dtype=torch.float32
707+
)
708+
grid_w = torch.linspace(
709+
start[1], stop[1] * (grid_size_w - 1) / grid_size_w, grid_size_w, device=device, dtype=torch.float32
710+
)
697711

698712
# Compute dimensions for each axis
699713
dim_t = embed_dim // 3

src/diffusers/pipelines/allegro/pipeline_allegro.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -623,20 +623,17 @@ def _prepare_rotary_positional_embeddings(
623623
self.transformer.config.interpolation_scale_h,
624624
self.transformer.config.interpolation_scale_w,
625625
),
626+
device=device,
626627
)
627628

628-
grid_t = torch.from_numpy(grid_t).to(device=device, dtype=torch.long)
629-
grid_h = torch.from_numpy(grid_h).to(device=device, dtype=torch.long)
630-
grid_w = torch.from_numpy(grid_w).to(device=device, dtype=torch.long)
629+
grid_t = grid_t.to(dtype=torch.long)
630+
grid_h = grid_h.to(dtype=torch.long)
631+
grid_w = grid_w.to(dtype=torch.long)
631632

632633
pos = torch.cartesian_prod(grid_t, grid_h, grid_w)
633634
pos = pos.reshape(-1, 3).transpose(0, 1).reshape(3, 1, -1).contiguous()
634635
grid_t, grid_h, grid_w = pos
635636

636-
freqs_t = (freqs_t[0].to(device=device), freqs_t[1].to(device=device))
637-
freqs_h = (freqs_h[0].to(device=device), freqs_h[1].to(device=device))
638-
freqs_w = (freqs_w[0].to(device=device), freqs_w[1].to(device=device))
639-
640637
return (freqs_t, freqs_h, freqs_w), (grid_t, grid_h, grid_w)
641638

642639
@property

src/diffusers/pipelines/cogvideo/pipeline_cogvideox.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ def _prepare_rotary_positional_embeddings(
459459
crops_coords=grid_crops_coords,
460460
grid_size=(grid_height, grid_width),
461461
temporal_size=num_frames,
462+
device=device,
462463
)
463464
else:
464465
# CogVideoX 1.5
@@ -471,10 +472,9 @@ def _prepare_rotary_positional_embeddings(
471472
temporal_size=base_num_frames,
472473
grid_type="slice",
473474
max_size=(base_size_height, base_size_width),
475+
device=device,
474476
)
475477

476-
freqs_cos = freqs_cos.to(device=device)
477-
freqs_sin = freqs_sin.to(device=device)
478478
return freqs_cos, freqs_sin
479479

480480
@property

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -505,6 +505,7 @@ def _prepare_rotary_positional_embeddings(
505505
crops_coords=grid_crops_coords,
506506
grid_size=(grid_height, grid_width),
507507
temporal_size=num_frames,
508+
device=device,
508509
)
509510
else:
510511
# CogVideoX 1.5
@@ -517,10 +518,9 @@ def _prepare_rotary_positional_embeddings(
517518
temporal_size=base_num_frames,
518519
grid_type="slice",
519520
max_size=(base_size_height, base_size_width),
521+
device=device,
520522
)
521523

522-
freqs_cos = freqs_cos.to(device=device)
523-
freqs_sin = freqs_sin.to(device=device)
524524
return freqs_cos, freqs_sin
525525

526526
@property

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -555,6 +555,7 @@ def _prepare_rotary_positional_embeddings(
555555
crops_coords=grid_crops_coords,
556556
grid_size=(grid_height, grid_width),
557557
temporal_size=num_frames,
558+
device=device,
558559
)
559560
else:
560561
# CogVideoX 1.5
@@ -567,10 +568,9 @@ def _prepare_rotary_positional_embeddings(
567568
temporal_size=base_num_frames,
568569
grid_type="slice",
569570
max_size=(base_size_height, base_size_width),
571+
device=device,
570572
)
571573

572-
freqs_cos = freqs_cos.to(device=device)
573-
freqs_sin = freqs_sin.to(device=device)
574574
return freqs_cos, freqs_sin
575575

576576
@property

src/diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,6 +529,7 @@ def _prepare_rotary_positional_embeddings(
529529
crops_coords=grid_crops_coords,
530530
grid_size=(grid_height, grid_width),
531531
temporal_size=num_frames,
532+
device=device,
532533
)
533534
else:
534535
# CogVideoX 1.5
@@ -541,10 +542,9 @@ def _prepare_rotary_positional_embeddings(
541542
temporal_size=base_num_frames,
542543
grid_type="slice",
543544
max_size=(base_size_height, base_size_width),
545+
device=device,
544546
)
545547

546-
freqs_cos = freqs_cos.to(device=device)
547-
freqs_sin = freqs_sin.to(device=device)
548548
return freqs_cos, freqs_sin
549549

550550
@property

0 commit comments

Comments
 (0)