Skip to content

Commit 72e22e8

Browse files
Avoid creating tensor in CosmosAttnProcessor2_0 (#11761)
1 parent 0874dd0 commit 72e22e8

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

src/diffusers/models/transformers/transformer_cosmos.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,9 @@ def __call__(
186186
key = apply_rotary_emb(key, image_rotary_emb, use_real=True, use_real_unbind_dim=-2)
187187

188188
# 4. Prepare for GQA
189-
query_idx = torch.tensor(query.size(3), device=query.device)
190-
key_idx = torch.tensor(key.size(3), device=key.device)
191-
value_idx = torch.tensor(value.size(3), device=value.device)
189+
query_idx = query.size(3)
190+
key_idx = key.size(3)
191+
value_idx = value.size(3)
192192
key = key.repeat_interleave(query_idx // key_idx, dim=3)
193193
value = value.repeat_interleave(query_idx // value_idx, dim=3)
194194

0 commit comments

Comments
 (0)