Skip to content

Commit fff9e31

Browse files
committed
up
1 parent 72e22e8 commit fff9e31

File tree

1 file changed

+9
-3
lines changed

1 file changed

+9
-3
lines changed

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,15 @@ def __call__(
186186
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
187187

188188
# 4. Prepare for GQA
189-
query_idx = query.size(3)
190-
key_idx = key.size(3)
191-
value_idx = value.size(3)
189+
if torch.onnx.is_in_onnx_export():
190+
query_idx = torch.tensor(query.size(3), device=query.device)
191+
key_idx = torch.tensor(key.size(3), device=key.device)
192+
value_idx = torch.tensor(value.size(3), device=value.device)
193+
194+
else:
195+
query_idx = query.size(3)
196+
key_idx = key.size(3)
197+
value_idx = value.size(3)
192198
key = key.repeat_interleave(query_idx // key_idx, dim=3)
193199
value = value.repeat_interleave(query_idx // value_idx, dim=3)
194200

0 commit comments

Comments
 (0)