Skip to content

Commit 7d18503

Browse files
Apply style fixes
1 parent 74b591a commit 7d18503

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,12 @@ def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tens
254254
height, width = height // self.patch_size, width // self.patch_size
255255

256256
dim_h, dim_w = self.dim // 2, self.dim // 2
257-
h_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h))
258-
w_inv_freq = 1.0 / (self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
257+
h_inv_freq = 1.0 / (
258+
self.theta ** (torch.arange(0, dim_h, 2, dtype=torch.float32)[: (dim_h // 2)].float() / dim_h)
259+
)
260+
w_inv_freq = 1.0 / (
261+
self.theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w)
262+
)
259263
h_seq = torch.arange(self.rope_axes_dim[0])
260264
w_seq = torch.arange(self.rope_axes_dim[1])
261265
freqs_h = torch.outer(h_seq, h_inv_freq)

0 commit comments

Comments
 (0)