Commit 25727e4

Merge branch 'main' into inplace_sum_and_remove_padding_and_better_memory_count
2 parents e264448 + 1ae9b05

File tree

1 file changed: +6 -5 lines changed


src/diffusers/models/attention_processor.py

Lines changed: 6 additions & 5 deletions
@@ -405,11 +405,12 @@ def set_use_memory_efficient_attention_xformers(
             else:
                 try:
                     # Make sure we can run the memory efficient attention
-                    _ = xformers.ops.memory_efficient_attention(
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                        torch.randn((1, 2, 40), device="cuda"),
-                    )
+                    dtype = None
+                    if attention_op is not None:
+                        op_fw, op_bw = attention_op
+                        dtype, *_ = op_fw.SUPPORTED_DTYPES
+                    q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                    _ = xformers.ops.memory_efficient_attention(q, q, q)
                 except Exception as e:
                     raise e
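The probe used to allocate its test tensors in the default float32 dtype. With this change, when an attention_op pair is passed, the probe tensor's dtype is taken from the forward op's SUPPORTED_DTYPES, so the check runs in a precision the requested op can actually handle. Below is a minimal sketch of the new probe logic outside the class, assuming xformers is installed and a CUDA device is available; MemoryEfficientAttentionFlashAttentionOp is only an illustrative dtype-restricted op and is not named in this commit.

# Sketch of the new capability probe (assumptions: xformers installed, CUDA available;
# MemoryEfficientAttentionFlashAttentionOp is just one example of a dtype-restricted op).
import torch
import xformers.ops

attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp  # (forward op, backward op) pair

dtype = None
if attention_op is not None:
    op_fw, op_bw = attention_op
    # Take any dtype the forward op supports (e.g. flash attention only supports
    # half precision), so the probe tensor matches what the op can run.
    dtype, *_ = op_fw.SUPPORTED_DTYPES

q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
_ = xformers.ops.memory_efficient_attention(q, q, q)  # raises if the setup cannot run

Note that, as in the patched code, the probe call itself does not pass the op, so it verifies that some memory-efficient kernel can run at that dtype rather than the specific op requested.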
