Skip to content

Commit bfc1dd4

Browse files
committed
Fix: enable memory-efficient attention on ROCm
by calling the CK implementation
1 parent 74b6752 commit bfc1dd4

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

src/diffusers/models/attention_processor.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -399,11 +399,8 @@ def set_use_memory_efficient_attention_xformers(
399399
else:
400400
try:
401401
# Make sure we can run the memory efficient attention
402-
_ = xformers.ops.memory_efficient_attention(
403-
torch.randn((1, 2, 40), device="cuda"),
404-
torch.randn((1, 2, 40), device="cuda"),
405-
torch.randn((1, 2, 40), device="cuda"),
406-
)
402+
q = torch.randn((1, 2, 40), device="cuda", dtype=torch.float16)
403+
_ = xformers.ops.memory_efficient_attention(q, q, q)
407404
except Exception as e:
408405
raise e
409406

0 commit comments

Comments
 (0)