|
4 | 4 |
|
5 | 5 | import torch |
6 | 6 |
|
7 | | -from diffusers import EulerDiscreteScheduler, StableDiffusionPipeline |
| 7 | +from diffusers import EulerDiscreteScheduler, StableDiffusionInstructPix2PixPipeline, StableDiffusionPipeline |
8 | 8 | from diffusers.loaders.single_file_utils import _extract_repo_id_and_weights_name |
9 | 9 | from diffusers.utils.testing_utils import ( |
10 | 10 | backend_empty_cache, |
@@ -118,3 +118,39 @@ def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0 |
118 | 118 |
|
119 | 119 | def test_single_file_format_inference_is_same_as_pretrained(self): |
120 | 120 | super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3) |
| 121 | + |
| 122 | + |
| 123 | +@slow |
| 124 | +@require_torch_accelerator |
| 125 | +class StableDiffusionInstructPix2PixPipelineSingleFileSlowTests(unittest.TestCase, SDSingleFileTesterMixin): |
| 126 | + pipeline_class = StableDiffusionInstructPix2PixPipeline |
| 127 | + ckpt_path = "https://huggingface.co/timbrooks/instruct-pix2pix/blob/main/instruct-pix2pix-00-22000.safetensors" |
| 128 | + original_config = ( |
| 129 | + "https://raw.githubusercontent.com/timothybrooks/instruct-pix2pix/refs/heads/main/configs/generate.yaml" |
| 130 | + ) |
| 131 | + repo_id = "timbrooks/instruct-pix2pix" |
| 132 | + |
| 133 | + def setUp(self): |
| 134 | + super().setUp() |
| 135 | + gc.collect() |
| 136 | + backend_empty_cache(torch_device) |
| 137 | + |
| 138 | + def tearDown(self): |
| 139 | + super().tearDown() |
| 140 | + gc.collect() |
| 141 | + backend_empty_cache(torch_device) |
| 142 | + |
| 143 | + def get_inputs(self, device, generator_device="cpu", dtype=torch.float32, seed=0): |
| 144 | + generator = torch.Generator(device=generator_device).manual_seed(seed) |
| 145 | + inputs = { |
| 146 | + "prompt": "a fantasy landscape, concept art, high resolution", |
| 147 | + "generator": generator, |
| 148 | + "num_inference_steps": 2, |
| 149 | + "strength": 0.75, |
| 150 | + "guidance_scale": 7.5, |
| 151 | + "output_type": "np", |
| 152 | + } |
| 153 | + return inputs |
| 154 | + |
| 155 | + def test_single_file_format_inference_is_same_as_pretrained(self): |
| 156 | + super().test_single_file_format_inference_is_same_as_pretrained(expected_max_diff=1e-3) |
0 commit comments