Commit 7fc465f

Fix Graph Breaks When Compiling CogView4
Eliminate this:

```
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles] Recompiling function forward in /home/zeyi/repos/diffusers/src/diffusers/models/transformers/transformer_cogview4.py:374
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     triggered by the following guard failure(s):
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/3: ___check_obj_id(L['self'].rope.freqs_h, 139976127328032)
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/2: ___check_obj_id(L['self'].rope.freqs_h, 139976107780960)
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/1: ___check_obj_id(L['self'].rope.freqs_h, 140022511848960)
V0304 10:24:23.421000 3131076 torch/_dynamo/guards.py:2813] [0/4] [__recompiles]     - 0/0: ___check_obj_id(L['self'].rope.freqs_h, 140024081342416)
```
1 parent 97fda1b commit 7fc465f
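
The guard failures above come from `self.freqs_h` / `self.freqs_w` being plain tensor attributes that get reassigned inside `forward`: Dynamo guards on the attribute's object id, so every new tensor object invalidates the compiled graph. Below is a minimal sketch of the problematic pattern and of the fix pattern used in this commit, assuming PyTorch 2.5+ (where `torch.nn.Buffer` is available); the toy modules are illustrative, not the actual diffusers classes.

```python
import torch


class RopeBefore(torch.nn.Module):
    """Anti-pattern: a plain tensor attribute reassigned inside forward."""

    def __init__(self):
        super().__init__()
        self.freqs_h = torch.outer(torch.arange(8.0), torch.arange(4.0))

    def forward(self, x):
        # When the device differs, .to() returns a *new* tensor object, so the
        # compiled graph's ___check_obj_id guard on self.freqs_h fails on the
        # next call and torch.compile recompiles forward.
        self.freqs_h = self.freqs_h.to(x.device)
        return x + self.freqs_h.sum()


class RopeAfter(torch.nn.Module):
    """Fix: keep the table as a registered buffer and never reassign it."""

    def __init__(self):
        super().__init__()
        # torch.nn.Buffer registers the tensor with the module, so module.to(device)
        # moves it once and its identity stays stable across forward calls.
        self.freqs_h = torch.nn.Buffer(torch.outer(torch.arange(8.0), torch.arange(4.0)))

    def forward(self, x):
        return x + self.freqs_h.sum()
```

With the buffers in place, the per-call `.to(hidden_states.device)` lines in `forward` become unnecessary and are removed in the diff below; the remaining `torch.arange(..., device=...)` change keeps the index tensors on the same device as the buffers.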

File tree

1 file changed: +4 -6 lines changed


src/diffusers/models/transformers/transformer_cogview4.py

Lines changed: 4 additions & 6 deletions
```diff
@@ -252,20 +252,18 @@ def __init__(self, dim: int, patch_size: int, rope_axes_dim: Tuple[int, int], th
         w_inv_freq = 1.0 / (theta ** (torch.arange(0, dim_w, 2, dtype=torch.float32)[: (dim_w // 2)].float() / dim_w))
         h_seq = torch.arange(self.rope_axes_dim[0])
         w_seq = torch.arange(self.rope_axes_dim[1])
-        self.freqs_h = torch.outer(h_seq, h_inv_freq)
-        self.freqs_w = torch.outer(w_seq, w_inv_freq)
+        self.freqs_h = torch.nn.Buffer(torch.outer(h_seq, h_inv_freq))
+        self.freqs_w = torch.nn.Buffer(torch.outer(w_seq, w_inv_freq))
 
     def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         batch_size, num_channels, height, width = hidden_states.shape
         height, width = height // self.patch_size, width // self.patch_size
 
-        h_idx = torch.arange(height)
-        w_idx = torch.arange(width)
+        h_idx = torch.arange(height, device=self.freqs_h.device)
+        w_idx = torch.arange(width, device=self.freqs_w.device)
         inner_h_idx = h_idx * self.rope_axes_dim[0] // height
         inner_w_idx = w_idx * self.rope_axes_dim[1] // width
 
-        self.freqs_h = self.freqs_h.to(hidden_states.device)
-        self.freqs_w = self.freqs_w.to(hidden_states.device)
         freqs_h = self.freqs_h[inner_h_idx]
         freqs_w = self.freqs_w[inner_w_idx]
 
```
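
As a quick sanity check of the buffer semantics the diff relies on, the sketch below uses an illustrative stand-in for the rotary-embedding module in `transformer_cogview4.py` (shapes and `theta` are arbitrary; requires PyTorch 2.5+ for `torch.nn.Buffer`):

```python
import torch


class Rope(torch.nn.Module):
    # Illustrative stand-in, not the diffusers class: the frequency tables are
    # registered as buffers, mirroring the change in this commit.
    def __init__(self, rope_axes_dim=(4, 6), dim=8, theta=10000.0):
        super().__init__()
        inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
        self.freqs_h = torch.nn.Buffer(torch.outer(torch.arange(rope_axes_dim[0]), inv_freq))
        self.freqs_w = torch.nn.Buffer(torch.outer(torch.arange(rope_axes_dim[1]), inv_freq))


rope = Rope()
print("freqs_h" in rope.state_dict())  # True: buffers are tracked by the module
rope = rope.to(torch.float64)          # module-level .to() moves/casts buffers as well,
print(rope.freqs_h.dtype)              # torch.float64 -- no per-call .to() needed in forward
```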