diff --git a/tests/models/autoencoders/test_models_autoencoder_kl.py b/tests/models/autoencoders/test_models_autoencoder_kl.py index 662a3f1b80b7..79aac6c7d1a7 100644 --- a/tests/models/autoencoders/test_models_autoencoder_kl.py +++ b/tests/models/autoencoders/test_models_autoencoder_kl.py @@ -35,13 +35,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin enable_full_determinism() -class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class AutoencoderKLTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = AutoencoderKL main_input_name = "sample" base_precision = 1e-2 diff --git a/tests/models/test_modeling_common.py b/tests/models/test_modeling_common.py index 5e7be62342c3..afb841e865be 100644 --- a/tests/models/test_modeling_common.py +++ b/tests/models/test_modeling_common.py @@ -88,7 +88,12 @@ if is_peft_available(): + from peft import LoraConfig from peft.tuners.tuners_utils import BaseTunerLayer + from peft.utils import get_peft_model_state_dict + + from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY + from diffusers.loaders.peft import PeftAdapterMixin def caculate_expected_num_shards(index_map_path): @@ -1113,177 +1118,6 @@ def test_deprecated_kwargs(self): " from `_deprecated_kwargs = []`" ) - @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) - @torch.no_grad() - @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora=False): - from peft import LoraConfig - from peft.utils import get_peft_model_state_dict - - from diffusers.loaders.peft import PeftAdapterMixin - - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - if not issubclass(model.__class__, PeftAdapterMixin): - pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).") - - torch.manual_seed(0) - output_no_lora = model(**inputs_dict, return_dict=False)[0] - - denoiser_lora_config = LoraConfig( - r=rank, - lora_alpha=lora_alpha, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], - init_lora_weights=False, - use_dora=use_dora, - ) - model.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") - - torch.manual_seed(0) - outputs_with_lora = model(**inputs_dict, return_dict=False)[0] - - self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)) - - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))) - - state_dict_loaded = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")) - - model.unload_lora() - self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") - - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0") - - for k in state_dict_loaded: - loaded_v = state_dict_loaded[k] - retrieved_v = state_dict_retrieved[k].to(loaded_v.device) - self.assertTrue(torch.allclose(loaded_v, retrieved_v)) - - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") - - torch.manual_seed(0) - outputs_with_lora_2 = 
model(**inputs_dict, return_dict=False)[0] - - self.assertFalse(torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) - self.assertTrue(torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)) - - @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_lora_wrong_adapter_name_raises_error(self): - from peft import LoraConfig - - from diffusers.loaders.peft import PeftAdapterMixin - - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - if not issubclass(model.__class__, PeftAdapterMixin): - pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).") - - denoiser_lora_config = LoraConfig( - r=4, - lora_alpha=4, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], - init_lora_weights=False, - use_dora=False, - ) - model.add_adapter(denoiser_lora_config) - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") - - with tempfile.TemporaryDirectory() as tmpdir: - wrong_name = "foo" - with self.assertRaises(ValueError) as err_context: - model.save_lora_adapter(tmpdir, adapter_name=wrong_name) - - self.assertTrue(f"Adapter name {wrong_name} not found in the model." in str(err_context.exception)) - - @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)]) - @torch.no_grad() - @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora): - from peft import LoraConfig - - from diffusers.loaders.peft import PeftAdapterMixin - - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - if not issubclass(model.__class__, PeftAdapterMixin): - pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).") - - denoiser_lora_config = LoraConfig( - r=rank, - lora_alpha=lora_alpha, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], - init_lora_weights=False, - use_dora=use_dora, - ) - model.add_adapter(denoiser_lora_config) - metadata = model.peft_config["default"].to_dict() - self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly") - - with tempfile.TemporaryDirectory() as tmpdir: - model.save_lora_adapter(tmpdir) - model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors") - self.assertTrue(os.path.isfile(model_file)) - - model.unload_lora() - self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly") - - model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) - parsed_metadata = model.peft_config["default_0"].to_dict() - check_if_dicts_are_equal(metadata, parsed_metadata) - - @torch.no_grad() - @unittest.skipIf(not is_peft_available(), "Only with PEFT") - def test_lora_adapter_wrong_metadata_raises_error(self): - from peft import LoraConfig - - from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY - from diffusers.loaders.peft import PeftAdapterMixin - - init_dict, _ = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict).to(torch_device) - - if not issubclass(model.__class__, PeftAdapterMixin): - pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).") - - denoiser_lora_config = LoraConfig( - r=4, - lora_alpha=4, - target_modules=["to_q", "to_k", "to_v", "to_out.0"], - init_lora_weights=False, - use_dora=False, - ) - model.add_adapter(denoiser_lora_config) - 
self.assertTrue(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            model.save_lora_adapter(tmpdir)
-            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
-            self.assertTrue(os.path.isfile(model_file))
-
-            # Perturb the metadata in the state dict.
-            loaded_state_dict = safetensors.torch.load_file(model_file)
-            metadata = {"format": "pt"}
-            lora_adapter_metadata = denoiser_lora_config.to_dict()
-            lora_adapter_metadata.update({"foo": 1, "bar": 2})
-            for key, value in lora_adapter_metadata.items():
-                if isinstance(value, set):
-                    lora_adapter_metadata[key] = list(value)
-            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
-            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
-
-            model.unload_lora()
-            self.assertFalse(check_if_lora_correctly_set(model), "LoRA layers not set correctly")
-
-            with self.assertRaises(TypeError) as err_context:
-                model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
-            self.assertTrue("`LoraConfig` class could not be instantiated" in str(err_context.exception))
-
     @require_torch_accelerator
     def test_cpu_offload(self):
         if self.model_class._no_split_modules is None:
@@ -1941,6 +1775,154 @@ def test_passing_dict_device_map_works(self, name, device):
         _ = loaded_model(**inputs_dict)
 
 
+class PEFTTesterMixin:
+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
+    @torch.no_grad()
+    @require_peft_backend
+    def test_save_load_lora_adapter(self, rank, lora_alpha, use_dora):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        torch.manual_seed(0)
+        output_no_lora = model(**inputs_dict, return_dict=False)[0]
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        torch.manual_seed(0)
+        outputs_with_lora = model(**inputs_dict, return_dict=False)[0]
+        assert not torch.allclose(output_no_lora, outputs_with_lora, atol=1e-4, rtol=1e-4)
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            state_dict_loaded = safetensors.torch.load_file(model_file)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            state_dict_retrieved = get_peft_model_state_dict(model, adapter_name="default_0")
+
+            for k, loaded_v in state_dict_loaded.items():
+                retrieved_v = state_dict_retrieved[k].to(loaded_v.device)
+                assert torch.allclose(loaded_v, retrieved_v)
+
+            assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            torch.manual_seed(0)
+            outputs_with_lora_2 = model(**inputs_dict, return_dict=False)[0]
+            assert not torch.allclose(output_no_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
+            assert torch.allclose(outputs_with_lora, outputs_with_lora_2, atol=1e-4, rtol=1e-4)
+
+    @require_peft_backend
+    def test_lora_wrong_adapter_name_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            wrong_name = "foo"
+            with pytest.raises(ValueError, match=rf"Adapter name {wrong_name} not found in the model\."):
+                model.save_lora_adapter(tmpdir, adapter_name=wrong_name)
+
+    @parameterized.expand([(4, 4, True), (4, 8, False), (8, 4, False)])
+    @require_peft_backend
+    def test_lora_adapter_metadata_is_loaded_correctly(self, rank, lora_alpha, use_dora):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=rank,
+            lora_alpha=lora_alpha,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=use_dora,
+        )
+        model.add_adapter(denoiser_lora_config)
+        metadata = model.peft_config["default"].to_dict()
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            model.unload_lora()
+            assert not check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+            model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True)
+            parsed_metadata = model.peft_config["default_0"].to_dict()
+            check_if_dicts_are_equal(metadata, parsed_metadata)
+
+    @require_peft_backend
+    def test_lora_adapter_wrong_metadata_raises_error(self):
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict).to(torch_device)
+
+        if not issubclass(model.__class__, PeftAdapterMixin):
+            pytest.skip(f"PEFT is not supported for this model ({model.__class__.__name__}).")
+
+        denoiser_lora_config = LoraConfig(
+            r=4,
+            lora_alpha=4,
+            target_modules=["to_q", "to_k", "to_v", "to_out.0"],
+            init_lora_weights=False,
+            use_dora=False,
+        )
+        model.add_adapter(denoiser_lora_config)
+        assert check_if_lora_correctly_set(model), "LoRA layers not set correctly"
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            model.save_lora_adapter(tmpdir)
+            model_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
+            assert os.path.isfile(model_file)
+
+            # Perturb the metadata in the state dict.
+            loaded_state_dict = safetensors.torch.load_file(model_file)
+            metadata = {"format": "pt"}
+            lora_adapter_metadata = denoiser_lora_config.to_dict()
+            lora_adapter_metadata.update({"foo": 1, "bar": 2})
+            for key, value in list(lora_adapter_metadata.items()):
+                if isinstance(value, set):
+                    lora_adapter_metadata[key] = list(value)
+            metadata[LORA_ADAPTER_METADATA_KEY] = json.dumps(lora_adapter_metadata, indent=2, sort_keys=True)
+            safetensors.torch.save_file(loaded_state_dict, model_file, metadata=metadata)
+
+            model.unload_lora()
+            assert not 
check_if_lora_correctly_set(model), "LoRA layers not set correctly" + + with pytest.raises(TypeError, match=r"`LoraConfig` class could not be instantiated"): + model.load_lora_adapter(tmpdir, prefix=None, use_safetensors=True) + + @is_staging_test class ModelPushToHubTester(unittest.TestCase): identifier = uuid.uuid4() diff --git a/tests/models/transformers/test_models_prior.py b/tests/models/transformers/test_models_prior.py index af5ac4bbbd76..ec5829d31463 100644 --- a/tests/models/transformers/test_models_prior.py +++ b/tests/models/transformers/test_models_prior.py @@ -30,13 +30,13 @@ torch_all_close, torch_device, ) -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class PriorTransformerTests(ModelTesterMixin, unittest.TestCase): +class PriorTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = PriorTransformer main_input_name = "hidden_states" diff --git a/tests/models/transformers/test_models_transformer_aura_flow.py b/tests/models/transformers/test_models_transformer_aura_flow.py index ae8c3b7234a3..9c487f1d9dba 100644 --- a/tests/models/transformers/test_models_transformer_aura_flow.py +++ b/tests/models/transformers/test_models_transformer_aura_flow.py @@ -20,13 +20,13 @@ from diffusers import AuraFlowTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class AuraFlowTransformerTests(ModelTesterMixin, unittest.TestCase): +class AuraFlowTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = AuraFlowTransformer2DModel main_input_name = "hidden_states" # We override the items here because the transformer under consideration is small. diff --git a/tests/models/transformers/test_models_transformer_bria.py b/tests/models/transformers/test_models_transformer_bria.py index 9056590edffe..2ebee1744358 100644 --- a/tests/models/transformers/test_models_transformer_bria.py +++ b/tests/models/transformers/test_models_transformer_bria.py @@ -22,7 +22,12 @@ from diffusers.models.embeddings import ImageProjection from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import ( + LoraHotSwappingForModelTesterMixin, + ModelTesterMixin, + PEFTTesterMixin, + TorchCompileTesterMixin, +) enable_full_determinism() @@ -78,7 +83,7 @@ def create_bria_ip_adapter_state_dict(model): return ip_state_dict -class BriaTransformerTests(ModelTesterMixin, unittest.TestCase): +class BriaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = BriaTransformer2DModel main_input_name = "hidden_states" # We override the items here because the transformer under consideration is small. 
diff --git a/tests/models/transformers/test_models_transformer_chroma.py b/tests/models/transformers/test_models_transformer_chroma.py index 92ac8198ed06..2377ea0031b9 100644 --- a/tests/models/transformers/test_models_transformer_chroma.py +++ b/tests/models/transformers/test_models_transformer_chroma.py @@ -22,7 +22,12 @@ from diffusers.models.embeddings import ImageProjection from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import ( + LoraHotSwappingForModelTesterMixin, + ModelTesterMixin, + PEFTTesterMixin, + TorchCompileTesterMixin, +) enable_full_determinism() @@ -78,7 +83,7 @@ def create_chroma_ip_adapter_state_dict(model): return ip_state_dict -class ChromaTransformerTests(ModelTesterMixin, unittest.TestCase): +class ChromaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = ChromaTransformer2DModel main_input_name = "hidden_states" # We override the items here because the transformer under consideration is small. diff --git a/tests/models/transformers/test_models_transformer_cogvideox.py b/tests/models/transformers/test_models_transformer_cogvideox.py index f632add7e5a7..72e96796eb02 100644 --- a/tests/models/transformers/test_models_transformer_cogvideox.py +++ b/tests/models/transformers/test_models_transformer_cogvideox.py @@ -19,17 +19,14 @@ from diffusers import CogVideoXTransformer3DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class CogVideoXTransformerTests(ModelTesterMixin, unittest.TestCase): +class CogVideoXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = CogVideoXTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_cogview4.py b/tests/models/transformers/test_models_transformer_cogview4.py index 084c3b7cea41..57ea90a1b9c5 100644 --- a/tests/models/transformers/test_models_transformer_cogview4.py +++ b/tests/models/transformers/test_models_transformer_cogview4.py @@ -19,13 +19,13 @@ from diffusers import CogView4Transformer2DModel from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase): +class CogView3PlusTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = CogView4Transformer2DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_consisid.py b/tests/models/transformers/test_models_transformer_consisid.py index 77fc172d078a..9fd858b2341f 100644 --- a/tests/models/transformers/test_models_transformer_consisid.py +++ b/tests/models/transformers/test_models_transformer_consisid.py @@ -19,17 +19,14 @@ from diffusers import ConsisIDTransformer3DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from 
..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class ConsisIDTransformerTests(ModelTesterMixin, unittest.TestCase): +class ConsisIDTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = ConsisIDTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_flux.py b/tests/models/transformers/test_models_transformer_flux.py index 3ab02f797b5b..add10634dc61 100644 --- a/tests/models/transformers/test_models_transformer_flux.py +++ b/tests/models/transformers/test_models_transformer_flux.py @@ -22,7 +22,12 @@ from diffusers.models.embeddings import ImageProjection from ...testing_utils import enable_full_determinism, is_peft_available, torch_device -from ..test_modeling_common import LoraHotSwappingForModelTesterMixin, ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import ( + LoraHotSwappingForModelTesterMixin, + ModelTesterMixin, + PEFTTesterMixin, + TorchCompileTesterMixin, +) enable_full_determinism() @@ -80,7 +85,7 @@ def create_flux_ip_adapter_state_dict(model): return ip_state_dict -class FluxTransformerTests(ModelTesterMixin, unittest.TestCase): +class FluxTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = FluxTransformer2DModel main_input_name = "hidden_states" # We override the items here because the transformer under consideration is small. diff --git a/tests/models/transformers/test_models_transformer_hidream.py b/tests/models/transformers/test_models_transformer_hidream.py index fdd5f8c7fd07..c0b71b0626a0 100644 --- a/tests/models/transformers/test_models_transformer_hidream.py +++ b/tests/models/transformers/test_models_transformer_hidream.py @@ -19,17 +19,14 @@ from diffusers import HiDreamImageTransformer2DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class HiDreamTransformerTests(ModelTesterMixin, unittest.TestCase): +class HiDreamTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = HiDreamImageTransformer2DModel main_input_name = "hidden_states" model_split_percents = [0.8, 0.8, 0.9] diff --git a/tests/models/transformers/test_models_transformer_hunyuan_video.py b/tests/models/transformers/test_models_transformer_hunyuan_video.py index 385a5eefd58b..c04f0bcadfa0 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video.py @@ -18,17 +18,14 @@ from diffusers import HunyuanVideoTransformer3DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): +class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = HunyuanVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git 
a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py index 00a2b27e02b6..417d3bb0080a 100644 --- a/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py +++ b/tests/models/transformers/test_models_transformer_hunyuan_video_framepack.py @@ -18,17 +18,14 @@ from diffusers import HunyuanVideoFramepackTransformer3DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class HunyuanVideoTransformer3DTests(ModelTesterMixin, unittest.TestCase): +class HunyuanVideoTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = HunyuanVideoFramepackTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_ltx.py b/tests/models/transformers/test_models_transformer_ltx.py index e912463bbf6a..8679e63e5bc8 100644 --- a/tests/models/transformers/test_models_transformer_ltx.py +++ b/tests/models/transformers/test_models_transformer_ltx.py @@ -20,13 +20,13 @@ from diffusers import LTXVideoTransformer3DModel from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class LTXTransformerTests(ModelTesterMixin, unittest.TestCase): +class LTXTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = LTXVideoTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_lumina2.py b/tests/models/transformers/test_models_transformer_lumina2.py index 4efae3d4b713..657e61531bcb 100644 --- a/tests/models/transformers/test_models_transformer_lumina2.py +++ b/tests/models/transformers/test_models_transformer_lumina2.py @@ -19,17 +19,14 @@ from diffusers import Lumina2Transformer2DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, unittest.TestCase): +class Lumina2Transformer2DModelTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = Lumina2Transformer2DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_mochi.py b/tests/models/transformers/test_models_transformer_mochi.py index 931b5874ee78..6d38407f9257 100644 --- a/tests/models/transformers/test_models_transformer_mochi.py +++ b/tests/models/transformers/test_models_transformer_mochi.py @@ -20,13 +20,13 @@ from diffusers import MochiTransformer3DModel from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class MochiTransformerTests(ModelTesterMixin, unittest.TestCase): +class 
MochiTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = MochiTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_qwenimage.py b/tests/models/transformers/test_models_transformer_qwenimage.py index b24fa90503ef..2ff99b9b4a20 100644 --- a/tests/models/transformers/test_models_transformer_qwenimage.py +++ b/tests/models/transformers/test_models_transformer_qwenimage.py @@ -21,13 +21,13 @@ from diffusers import QwenImageTransformer2DModel from ...testing_utils import enable_full_determinism, torch_device -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class QwenImageTransformerTests(ModelTesterMixin, unittest.TestCase): +class QwenImageTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = QwenImageTransformer2DModel main_input_name = "hidden_states" # We override the items here because the transformer under consideration is small. diff --git a/tests/models/transformers/test_models_transformer_sana.py b/tests/models/transformers/test_models_transformer_sana.py index 2e316c3aedc1..451186f29ca8 100644 --- a/tests/models/transformers/test_models_transformer_sana.py +++ b/tests/models/transformers/test_models_transformer_sana.py @@ -18,17 +18,14 @@ from diffusers import SanaTransformer2DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class SanaTransformerTests(ModelTesterMixin, unittest.TestCase): +class SanaTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = SanaTransformer2DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_sd3.py b/tests/models/transformers/test_models_transformer_sd3.py index c4ee7017a380..f785021d8be9 100644 --- a/tests/models/transformers/test_models_transformer_sd3.py +++ b/tests/models/transformers/test_models_transformer_sd3.py @@ -24,13 +24,13 @@ enable_full_determinism, torch_device, ) -from ..test_modeling_common import ModelTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin enable_full_determinism() -class SD3TransformerTests(ModelTesterMixin, unittest.TestCase): +class SD3TransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = SD3Transformer2DModel main_input_name = "hidden_states" model_split_percents = [0.8, 0.8, 0.9] diff --git a/tests/models/transformers/test_models_transformer_skyreels_v2.py b/tests/models/transformers/test_models_transformer_skyreels_v2.py index 8c36d8256ee9..ea0b4f34e600 100644 --- a/tests/models/transformers/test_models_transformer_skyreels_v2.py +++ b/tests/models/transformers/test_models_transformer_skyreels_v2.py @@ -18,17 +18,14 @@ from diffusers import SkyReelsV2Transformer3DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin 
enable_full_determinism() -class SkyReelsV2Transformer3DTests(ModelTesterMixin, TorchCompileTesterMixin, unittest.TestCase): +class SkyReelsV2Transformer3DTests(ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin, unittest.TestCase): model_class = SkyReelsV2Transformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/transformers/test_models_transformer_wan.py b/tests/models/transformers/test_models_transformer_wan.py index 9f248f990c8a..1671b157079a 100644 --- a/tests/models/transformers/test_models_transformer_wan.py +++ b/tests/models/transformers/test_models_transformer_wan.py @@ -18,17 +18,14 @@ from diffusers import WanTransformer3DModel -from ...testing_utils import ( - enable_full_determinism, - torch_device, -) -from ..test_modeling_common import ModelTesterMixin, TorchCompileTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, TorchCompileTesterMixin enable_full_determinism() -class WanTransformer3DTests(ModelTesterMixin, unittest.TestCase): +class WanTransformer3DTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = WanTransformer3DModel main_input_name = "hidden_states" uses_custom_attn_processor = True diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py index 4dbb8ca7c075..25818d6b8204 100644 --- a/tests/models/unets/test_models_unet_2d_condition.py +++ b/tests/models/unets/test_models_unet_2d_condition.py @@ -55,6 +55,7 @@ from ..test_modeling_common import ( LoraHotSwappingForModelTesterMixin, ModelTesterMixin, + PEFTTesterMixin, TorchCompileTesterMixin, UNetTesterMixin, ) @@ -354,7 +355,7 @@ def create_custom_diffusion_layers(model, mock_weights: bool = True): return custom_diffusion_attn_procs -class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, PEFTTesterMixin, unittest.TestCase): model_class = UNet2DConditionModel main_input_name = "sample" # We override the items here because the unet under consideration is small. @@ -1083,48 +1084,6 @@ def test_load_sharded_checkpoint_device_map_from_hub_local_subfolder(self): assert loaded_model assert new_output.sample.shape == (4, 4, 16, 16) - @require_peft_backend - def test_load_attn_procs_raise_warning(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() - model = self.model_class(**init_dict) - model.to(torch_device) - - # forward pass without LoRA - with torch.no_grad(): - non_lora_sample = model(**inputs_dict).sample - - unet_lora_config = get_unet_lora_config() - model.add_adapter(unet_lora_config) - - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." - - # forward pass with LoRA - with torch.no_grad(): - lora_sample_1 = model(**inputs_dict).sample - - with tempfile.TemporaryDirectory() as tmpdirname: - model.save_attn_procs(tmpdirname) - model.unload_lora() - - with self.assertWarns(FutureWarning) as warning: - model.load_attn_procs(os.path.join(tmpdirname, "pytorch_lora_weights.safetensors")) - - warning_message = str(warning.warnings[0].message) - assert "Using the `load_attn_procs()` method has been deprecated" in warning_message - - # import to still check for the rest of the stuff. - assert check_if_lora_correctly_set(model), "Lora not correctly set in UNet." 
- - with torch.no_grad(): - lora_sample_2 = model(**inputs_dict).sample - - assert not torch.allclose(non_lora_sample, lora_sample_1, atol=1e-4, rtol=1e-4), ( - "LoRA injected UNet should produce different results." - ) - assert torch.allclose(lora_sample_1, lora_sample_2, atol=1e-4, rtol=1e-4), ( - "Loading from a saved checkpoint should produce identical results." - ) - @require_peft_backend def test_save_attn_procs_raise_warning(self): init_dict, _ = self.prepare_init_args_and_inputs_for_common() diff --git a/tests/models/unets/test_models_unet_motion.py b/tests/models/unets/test_models_unet_motion.py index d931b345fd09..76b93b3a2449 100644 --- a/tests/models/unets/test_models_unet_motion.py +++ b/tests/models/unets/test_models_unet_motion.py @@ -30,7 +30,7 @@ floats_tensor, torch_device, ) -from ..test_modeling_common import ModelTesterMixin, UNetTesterMixin +from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin logger = logging.get_logger(__name__) @@ -38,7 +38,7 @@ enable_full_determinism() -class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase): +class UNetMotionModelTests(ModelTesterMixin, PEFTTesterMixin, UNetTesterMixin, unittest.TestCase): model_class = UNetMotionModel main_input_name = "sample"
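For reference, every adoption in this patch follows the same pattern: import PEFTTesterMixin from test_modeling_common and add it to the test class bases. A minimal sketch is below; ExampleTransformer2DModel and the body of prepare_init_args_and_inputs_for_common are hypothetical stand-ins for illustration, not part of this patch:

    import unittest

    from diffusers import ExampleTransformer2DModel  # hypothetical model class

    from ..test_modeling_common import ModelTesterMixin, PEFTTesterMixin


    class ExampleTransformerTests(ModelTesterMixin, PEFTTesterMixin, unittest.TestCase):
        model_class = ExampleTransformer2DModel
        main_input_name = "hidden_states"

        # PEFTTesterMixin reuses this ModelTesterMixin hook to build the model it
        # injects LoRA layers into; models that do not subclass PeftAdapterMixin
        # are skipped automatically inside each PEFT test.
        def prepare_init_args_and_inputs_for_common(self):
            # Return (init_kwargs, forward_kwargs) for a small model configuration.
            ...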