tests/models/transformers — 1 file changed: 16 insertions(+), 0 deletions(-)

The change adds a test asserting that calling enable_xformers_memory_efficient_attention() on SD3Transformer2DModel swaps the attention processors for XFormersJointAttnProcessor:

@@ -18,6 +18,7 @@
 import torch

 from diffusers import SD3Transformer2DModel
+from diffusers.utils.import_utils import is_xformers_available
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     torch_device,
@@ -76,3 +77,18 @@ def prepare_init_args_and_inputs_for_common(self):
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
+
+    @unittest.skipIf(
+        torch_device != "cuda" or not is_xformers_available(),
+        reason="XFormers attention is only available with CUDA and `xformers` installed",
+    )
+    def test_xformers_enable_works(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+
+        model.enable_xformers_memory_efficient_attention()
+
+        assert (
+            model.transformer_blocks[0].attn.processor.__class__.__name__
+            == "XFormersJointAttnProcessor"
+        ), "xformers is not enabled"