Skip to content

Commit 9080bd6

Browse files
authored
Update transformer_cogview4.py
1 parent 7fc465f commit 9080bd6

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -252,8 +252,8 @@ def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], th
252252
w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
253253
h_seq = torch.arange(self.rope_axes_dim[0])
254254
w_seq = torch.arange(self.rope_axes_dim[1])
255-
self.freqs_h = torch.nn.Buffer(torch.outer(h_seq, h_inv_freq))
256-
self.freqs_w = torch.nn.Buffer(torch.outer(w_seq, w_inv_freq))
255+
self.freqs_h = self.register_buffer("freqs_h", torch.outer(h_seq, h_inv_freq), persistent=False)
256+
self.freqs_w = self.register_buffer("freqs_h", torch.outer(w_seq, w_inv_freq), persistent=False)
257257

258258
def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
259259
batch_size, num_channels, height, width = hidden_states.shape

0 commit comments

Comments
 (0)