Skip to content

Commit 0534e23

Browse files
committed
sd3 xformers test
1 parent 78b269a commit 0534e23

File tree

1 file changed

+16
-0
lines changed

1 file changed

+16
-0
lines changed

tests/models/transformers/test_models_transformer_sd3.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import torch
1919

2020
from diffusers import SD3Transformer2DModel
21+
from diffusers.utils.import_utils import is_xformers_available
2122
from diffusers.utils.testing_utils import (
2223
enable_full_determinism,
2324
torch_device,
@@ -76,3 +77,18 @@ def prepare_init_args_and_inputs_for_common(self):
7677
}
7778
inputs_dict = self.dummy_input
7879
return init_dict, inputs_dict
80+
81+
@unittest.skipIf(
    torch_device != "cuda" or not is_xformers_available(),
    reason="XFormers attention is only available with CUDA and `xformers` installed",
)
def test_xformers_enable_works(self):
    """Calling `enable_xformers_memory_efficient_attention` should install the
    xformers joint-attention processor on the model's transformer blocks."""
    # Build the model exactly as the shared test mixin would.
    init_dict, _ = self.prepare_init_args_and_inputs_for_common()
    model = self.model_class(**init_dict)

    model.enable_xformers_memory_efficient_attention()

    # Inspect the first block's attention processor; after enabling, it must be
    # the xformers variant (checked by class name, matching the original test).
    processor_name = model.transformer_blocks[0].attn.processor.__class__.__name__
    assert processor_name == "XFormersJointAttnProcessor", "xformers is not enabled"

0 commit comments

Comments
 (0)