Skip to content

Commit 74b591a

Browse files
committed
fix cogview4 rotary pos embed
1 parent 9080bd6 commit 74b591a

File tree

1 file changed

+14
-12
lines changed

1 file changed

+14
-12
lines changed

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -244,28 +244,30 @@ class CogView4RotaryPosEmbed(nn.Module):
244244
def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], theta: float = 10000.0) -> None:
245245
super().__init__()
246246

247+
self.dim = dim
247248
self.patch_size = patch_size
248249
self.rope_axes_dim = rope_axes_dim
249-
250-
dim_h, dim_w = dim // 2, dim // 2
251-
h_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
252-
w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
253-
h_seq = torch.arange(self.rope_axes_dim[0])
254-
w_seq = torch.arange(self.rope_axes_dim[1])
255-
self.freqs_h = self.register_buffer("freqs_h", torch.outer(h_seq, h_inv_freq), persistent=False)
256-
self.freqs_w = self.register_buffer("freqs_h", torch.outer(w_seq, w_inv_freq), persistent=False)
250+
self.theta = theta
257251

258252
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
259253
batch_size, num_channels, height, width = hidden_states.shape
260254
height, width = height // self.patch_size, width // self.patch_size
261255

262-
h_idx = torch.arange(height, device=self.freqs_h.device)
263-
w_idx = torch.arange(width, device=self.freqs_w.device)
256+
dim_h, dim_w = self.dim // 2, self.dim // 2
257+
h_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
258+
w_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
259+
h_seq = torch.arange(self.rope_axes_dim[0])
260+
w_seq = torch.arange(self.rope_axes_dim[1])
261+
freqs_h = torch.outer(h_seq, h_inv_freq)
262+
freqs_w = torch.outer(w_seq, w_inv_freq)
263+
264+
h_idx = torch.arange(height, device=freqs_h.device)
265+
w_idx = torch.arange(width, device=freqs_w.device)
264266
inner_h_idx = h_idx * self.rope_axes_dim[0] // height
265267
inner_w_idx = w_idx * self.rope_axes_dim[1] // width
266268

267-
freqs_h = self.freqs_h[inner_h_idx]
268-
freqs_w = self.freqs_w[inner_w_idx]
269+
freqs_h = freqs_h[inner_h_idx]
270+
freqs_w = freqs_w[inner_w_idx]
269271

270272
# Create position matrices for height and width
271273
# [height, 1, dim//4] and [1, width, dim//4]

0 commit comments

Comments
 (0)