|  | 
| 19 | 19 | import numpy as np | 
| 20 | 20 | import pytest | 
| 21 | 21 | from huggingface_hub import hf_hub_download | 
|  | 22 | +from PIL import Image | 
| 22 | 23 | 
 | 
| 23 | 24 | from diffusers import ( | 
| 24 | 25 |     BitsAndBytesConfig, | 
| 25 | 26 |     DiffusionPipeline, | 
|  | 27 | +    FluxControlPipeline, | 
| 26 | 28 |     FluxTransformer2DModel, | 
| 27 | 29 |     SanaTransformer2DModel, | 
| 28 | 30 |     SD3Transformer2DModel, | 
| 29 | 31 |     logging, | 
| 30 | 32 | ) | 
|  | 33 | +from diffusers.quantizers import PipelineQuantizationConfig | 
| 31 | 34 | from diffusers.utils import is_accelerate_version | 
| 32 | 35 | from diffusers.utils.testing_utils import ( | 
| 33 | 36 |     CaptureLogger, | 
|  | 
| 39 | 42 |     numpy_cosine_similarity_distance, | 
| 40 | 43 |     require_accelerate, | 
| 41 | 44 |     require_bitsandbytes_version_greater, | 
|  | 45 | +    require_peft_backend, | 
| 42 | 46 |     require_peft_version_greater, | 
| 43 | 47 |     require_torch, | 
| 44 | 48 |     require_torch_accelerator, | 
| @@ -696,6 +700,50 @@ def test_lora_loading(self): | 
| 696 | 700 |         self.assertTrue(max_diff < 1e-3) | 
| 697 | 701 | 
 | 
| 698 | 702 | 
 | 
@slow
@require_transformers_version_greater("4.44.0")
@require_peft_backend
class SlowBnb4BitFluxControlWithLoraTests(Base8bitTests):
    """Integration test: load a Canny control LoRA into an 8-bit
    bitsandbytes-quantized ``FluxControlPipeline`` and pin a deterministic
    output slice.

    NOTE(review): despite "4Bit" in the class name, this suite quantizes with
    ``load_in_8bit=True`` and derives from ``Base8bitTests`` — the name looks
    like a copy/paste misnomer from the 4-bit test file. Renaming would change
    pytest selection IDs, so it is only flagged here; confirm before renaming.
    """

    def setUp(self) -> None:
        # Start from a clean accelerator state so allocations left over from
        # earlier tests do not skew memory use in this one.
        gc.collect()
        backend_empty_cache(torch_device)

        # Quantize only the heavy components (DiT transformer + T5 encoder);
        # the remaining components are loaded in plain fp16.
        self.pipeline_8bit = FluxControlPipeline.from_pretrained(
            "black-forest-labs/FLUX.1-dev",
            quantization_config=PipelineQuantizationConfig(
                quant_backend="bitsandbytes_8bit",
                quant_kwargs={"load_in_8bit": True},
                components_to_quantize=["transformer", "text_encoder_2"],
            ),
            torch_dtype=torch.float16,
        )
        self.pipeline_8bit.enable_model_cpu_offload()

    def tearDown(self):
        # Drop the pipeline reference before clearing caches so accelerator
        # memory is actually reclaimed for the next test.
        del self.pipeline_8bit

        gc.collect()
        backend_empty_cache(torch_device)

    def test_lora_loading(self):
        # Loading LoRA weights on top of 8-bit quantized modules is the
        # behavior under test; inference below just checks it still produces
        # the expected numerics.
        self.pipeline_8bit.load_lora_weights("black-forest-labs/FLUX.1-Canny-dev-lora")

        output = self.pipeline_8bit(
            prompt=self.prompt,
            # A blank control image suffices: the test only pins numerical
            # reproducibility, not visual quality.
            control_image=Image.new(mode="RGB", size=(256, 256)),
            height=256,
            width=256,
            max_sequence_length=64,
            output_type="np",
            num_inference_steps=8,
            generator=torch.Generator().manual_seed(42),
        ).images
        out_slice = output[0, -3:, -3:, -1].flatten()
        expected_slice = np.array([0.2029, 0.2136, 0.2268, 0.1921, 0.1997, 0.2185, 0.2021, 0.2183, 0.2292])

        # Cosine-distance comparison tolerates small elementwise drift from
        # kernel-level nondeterminism better than a strict allclose.
        max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
        self.assertTrue(max_diff < 1e-3, msg=f"{out_slice=} != {expected_slice=}")
|  | 746 | + | 
| 699 | 747 | @slow | 
| 700 | 748 | class BaseBnb8bitSerializationTests(Base8bitTests): | 
| 701 | 749 |     def setUp(self): | 
|  | 
0 commit comments