|
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | 15 |
|
| 16 | +import gc |
16 | 17 | import random |
17 | 18 | import unittest |
18 | 19 |
|
|
31 | 32 | from diffusers import ( |
32 | 33 | AutoencoderKL, |
33 | 34 | AutoencoderTiny, |
| 35 | + EDMDPMSolverMultistepScheduler, |
34 | 36 | EulerDiscreteScheduler, |
35 | 37 | LCMScheduler, |
36 | 38 | StableDiffusionXLImg2ImgPipeline, |
|
39 | 41 | from diffusers.utils.testing_utils import ( |
40 | 42 | enable_full_determinism, |
41 | 43 | floats_tensor, |
| 44 | + load_image, |
42 | 45 | require_torch_gpu, |
| 46 | + slow, |
43 | 47 | torch_device, |
44 | 48 | ) |
45 | 49 |
|
@@ -776,3 +780,54 @@ def test_inference_batch_single_identical(self): |
776 | 780 |
|
777 | 781 | def test_save_load_optional_components(self): |
778 | 782 | self._test_save_load_optional_components() |
| 783 | + |
| 784 | + |
| 785 | +@slow |
| 786 | +class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase): |
| 787 | + def setUp(self): |
| 788 | + super().setUp() |
| 789 | + gc.collect() |
| 790 | + torch.cuda.empty_cache() |
| 791 | + |
| 792 | + def tearDown(self): |
| 793 | + super().tearDown() |
| 794 | + gc.collect() |
| 795 | + torch.cuda.empty_cache() |
| 796 | + |
| 797 | + def test_stable_diffusion_xl_img2img_playground(self): |
| 798 | + torch.manual_seed(0) |
| 799 | + model_path = "playgroundai/playground-v2.5-1024px-aesthetic" |
| 800 | + |
| 801 | + sd_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained( |
| 802 | + model_path, torch_dtype=torch.float16, variant="fp16", add_watermarker=False |
| 803 | + ) |
| 804 | + |
| 805 | + sd_pipe.enable_model_cpu_offload() |
| 806 | + sd_pipe.scheduler = EDMDPMSolverMultistepScheduler.from_config( |
| 807 | + sd_pipe.scheduler.config, use_karras_sigmas=True |
| 808 | + ) |
| 809 | + sd_pipe.set_progress_bar_config(disable=None) |
| 810 | + |
| 811 | + prompt = "a photo of an astronaut riding a horse on mars" |
| 812 | + |
| 813 | + url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png" |
| 814 | + |
| 815 | + init_image = load_image(url).convert("RGB") |
| 816 | + |
| 817 | + image = sd_pipe( |
| 818 | + prompt, |
| 819 | + num_inference_steps=30, |
| 820 | + guidance_scale=8.0, |
| 821 | + image=init_image, |
| 822 | + height=1024, |
| 823 | + width=1024, |
| 824 | + output_type="np", |
| 825 | + ).images |
| 826 | + |
| 827 | + image_slice = image[0, -3:, -3:, -1] |
| 828 | + |
| 829 | + assert image.shape == (1, 1024, 1024, 3) |
| 830 | + |
| 831 | + expected_slice = np.array([0.3519, 0.3149, 0.3364, 0.3505, 0.3402, 0.3371, 0.3554, 0.3495, 0.3333]) |
| 832 | + |
| 833 | + assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2 |
0 commit comments