
Commit da420fb

a-r-r-o-w and yiyixuxu authored
Update src/diffusers/models/embeddings.py
Co-authored-by: YiYi Xu <[email protected]>
1 parent: a137e17 · commit: da420fb

File tree

1 file changed: +0 −75 lines changed


src/diffusers/models/embeddings.py

Lines changed: 0 additions & 75 deletions
@@ -2611,78 +2611,3 @@ def forward(self, image_embeds: List[torch.Tensor]):
             projected_image_embeds.append(image_embed)
 
         return projected_image_embeds
-
-
-class CogViewRotary2DEmbedding(nn.Module):
-    def __init__(
-        self,
-        kv_channels: int,
-        rotary_percent: float,
-        max_h: int = 128,
-        max_w: int = 128,
-        rotary_interleaved: bool = False,
-        seq_len_interpolation_factor: float = None,
-        inner_interp: bool = False,
-        rotary_base: int = 10000,
-    ) -> None:
-        super().__init__()
-
-        dim = kv_channels
-        if rotary_percent < 1.0:
-            dim = int(dim * rotary_percent)
-        self.rotary_interleaved = rotary_interleaved
-
-        self.seq_len_interpolation_factor = seq_len_interpolation_factor
-        self.inner_interp = inner_interp
-
-        dim_h = kv_channels // 2
-        dim_w = kv_channels // 2
-
-        device = torch.cuda.current_device()
-        h_inv_freq = 1.0 / (
-            rotary_base
-            ** (torch.arange(0, dim_h, 2, dtype=torch.float32, device=device)[: (dim_h // 2)].float() / dim_h)
-        )
-        w_inv_freq = 1.0 / (
-            rotary_base
-            ** (torch.arange(0, dim_w, 2, dtype=torch.float32, device=device)[: (dim_w // 2)].float() / dim_w)
-        )
-
-        h_seq = torch.arange(max_h, device=device, dtype=h_inv_freq.dtype)
-        w_seq = torch.arange(max_w, device=device, dtype=w_inv_freq.dtype)
-
-        self.freqs_h = torch.outer(h_seq, h_inv_freq)
-        self.freqs_w = torch.outer(w_seq, w_inv_freq)
-        self.max_h = max_h
-        self.max_w = max_w
-
-    def forward(
-        self,
-        h_idx: torch.Tensor,
-        w_idx: torch.Tensor,
-        target_h: torch.Tensor = None,
-        target_w: torch.Tensor = None,
-        mask: torch.Tensor = None,
-    ) -> torch.Tensor:
-        if self.inner_interp:
-            inner_h_idx = (h_idx * self.max_h) // target_h
-            inner_w_idx = (w_idx * self.max_w) // target_w
-
-            h_emb = self.freqs_h[inner_h_idx]
-            w_emb = self.freqs_w[inner_w_idx]
-
-        else:
-            h_emb = self.freqs_h[h_idx]
-            w_emb = self.freqs_w[w_idx]
-
-        mask = (mask == 1).unsqueeze(-1)
-
-        emb = torch.cat([h_emb, w_emb], dim=-1) * mask
-
-        assert emb.ndim == 2, f"expected emb to have 2 dimensions, got {emb.ndim}"
-        if not self.rotary_interleaved:
-            emb = torch.repeat_interleave(emb, 2, dim=0)
-        else:
-            emb = torch.repeat_interleave(emb, 2, dim=1)
-
-        return emb
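
For reference, a minimal sketch (not part of the commit) of the 2D rotary frequency tables the deleted CogViewRotary2DEmbedding built: each spatial axis gets an outer product of position indices with its own inverse-frequency vector, and the per-axis angles are indexed by (h, w) and concatenated at lookup time. The values below are hypothetical example defaults, and the sketch runs on CPU rather than torch.cuda.current_device().

import torch

# Hypothetical example values; the removed class defaulted max_h = max_w = 128
# and rotary_base = 10000, with kv_channels supplied by the caller.
kv_channels, rotary_base = 64, 10000
max_h = max_w = 128

# Half of the channels are assigned to the height axis, half to the width axis.
dim_h = dim_w = kv_channels // 2

# Inverse frequencies per axis, same formula as the removed __init__ (CPU here).
h_inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: dim_h // 2] / dim_h))
w_inv_freq = 1.0 / (rotary_base ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: dim_w // 2] / dim_w))

# Angle tables over all positions: shapes (max_h, dim_h // 2) and (max_w, dim_w // 2).
freqs_h = torch.outer(torch.arange(max_h, dtype=torch.float32), h_inv_freq)
freqs_w = torch.outer(torch.arange(max_w, dtype=torch.float32), w_inv_freq)

# Lookup for a few (h, w) token positions, mirroring the non-interpolating
# branch of the removed forward(): concatenate the per-axis angles.
h_idx = torch.tensor([0, 1, 2])
w_idx = torch.tensor([5, 6, 7])
emb = torch.cat([freqs_h[h_idx], freqs_w[w_idx]], dim=-1)  # shape (3, kv_channels // 2)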

0 commit comments