Commit 3fca520

🎨 fix xl playground device (#8550)

Authored by Honey-666 and sayakpaul

* 🎨 fix xl playground device
* 🎨 run `make fix-copies`
* edit xl_controlnet_img2img file
* edit playground img2img test slow
* Update tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py

Co-authored-by: Sayak Paul <[email protected]>

1 parent c375903 · commit 3fca520
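The one-line change in both pipelines moves `latents_mean` / `latents_std` onto the `device` argument that the caller passes into `prepare_latents`, rather than onto `self.device`. The sketch below is illustrative only (it is not the diffusers source; the class, numbers, and names are stand-ins): under `enable_model_cpu_offload()` the pipeline's device property can report `cpu` while latents are being prepared on the CUDA execution device, so statistics that followed `self.device` ended up on the wrong device.

```python
# Minimal sketch of the device-mismatch fix; illustrative only, not the diffusers source.
import torch


class PipelineSketch:
    def __init__(self):
        # With model CPU offload, submodules sit on the CPU until they are used,
        # so a naive `self.device` lookup reports "cpu" even when inference runs on CUDA.
        self.device = torch.device("cpu")

    def prepare_latents_sketch(self, init_latents, latents_mean, latents_std, device, dtype):
        init_latents = init_latents.to(device=device, dtype=dtype)
        # Before the fix the statistics followed self.device ("cpu"), mismatching
        # init_latents on `device`; after the fix they follow the execution device
        # the caller passes in, exactly like the patched lines in the diff.
        latents_mean = latents_mean.to(device=device, dtype=dtype)
        latents_std = latents_std.to(device=device, dtype=dtype)
        scaling_factor = 0.13025  # illustrative value, stands in for self.vae.config.scaling_factor
        return (init_latents - latents_mean) * scaling_factor / latents_std


if __name__ == "__main__":
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    sketch = PipelineSketch()
    latents = sketch.prepare_latents_sketch(
        init_latents=torch.randn(1, 4, 128, 128),
        latents_mean=torch.rand(1, 4, 1, 1),
        latents_std=torch.rand(1, 4, 1, 1) + 0.5,
        device=dev,
        dtype=torch.float32,
    )
    print(latents.shape, latents.device)
```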

File tree: 3 files changed, +59 −4 lines

src/diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py (2 additions, 2 deletions)

```diff
@@ -949,8 +949,8 @@ def prepare_latents(
 
             init_latents = init_latents.to(dtype)
             if latents_mean is not None and latents_std is not None:
-                latents_mean = latents_mean.to(device=self.device, dtype=dtype)
-                latents_std = latents_std.to(device=self.device, dtype=dtype)
+                latents_mean = latents_mean.to(device=device, dtype=dtype)
+                latents_std = latents_std.to(device=device, dtype=dtype)
                 init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
             else:
                 init_latents = self.vae.config.scaling_factor * init_latents
```

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py (2 additions, 2 deletions)

```diff
@@ -723,8 +723,8 @@ def prepare_latents(
 
             init_latents = init_latents.to(dtype)
             if latents_mean is not None and latents_std is not None:
-                latents_mean = latents_mean.to(device=self.device, dtype=dtype)
-                latents_std = latents_std.to(device=self.device, dtype=dtype)
+                latents_mean = latents_mean.to(device=device, dtype=dtype)
+                latents_std = latents_std.to(device=device, dtype=dtype)
                 init_latents = (init_latents - latents_mean) * self.vae.config.scaling_factor / latents_std
             else:
                 init_latents = self.vae.config.scaling_factor * init_latents
```
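For context, a user-level setup that exercises the patched code path looks roughly like the snippet below, which mirrors the new slow test rather than adding anything beyond it: Playground v2.5 is a checkpoint whose VAE config carries `latents_mean` / `latents_std`, and combining it with model CPU offload is what previously tripped the `self.device` lookup.

```python
# Rough usage sketch mirroring the new integration test; needs a CUDA GPU and
# downloads the Playground v2.5 checkpoint.
import torch
from diffusers import EDMDPMSolverMultistepScheduler, StableDiffusionXLImg2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
    "playgroundai/playground-v2.5-1024px-aesthetic",
    torch_dtype=torch.float16,
    variant="fp16",
    add_watermarker=False,
)
pipe.scheduler = EDMDPMSolverMultistepScheduler.from_config(pipe.scheduler.config, use_karras_sigmas=True)
pipe.enable_model_cpu_offload()  # the offload path that made self.device report "cpu"

init_image = load_image(
    "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
).convert("RGB")

image = pipe(
    "a photo of an astronaut riding a horse on mars",
    image=init_image,
    num_inference_steps=30,
    guidance_scale=8.0,
    height=1024,
    width=1024,
).images[0]
image.save("playground_img2img.png")
```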

tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py (55 additions, 0 deletions)

```diff
@@ -13,6 +13,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import gc
 import random
 import unittest
 
@@ -31,6 +32,7 @@
 from diffusers import (
     AutoencoderKL,
     AutoencoderTiny,
+    EDMDPMSolverMultistepScheduler,
     EulerDiscreteScheduler,
     LCMScheduler,
     StableDiffusionXLImg2ImgPipeline,
@@ -39,7 +41,9 @@
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    load_image,
     require_torch_gpu,
+    slow,
     torch_device,
 )
 
@@ -776,3 +780,54 @@ def test_inference_batch_single_identical(self):
 
     def test_save_load_optional_components(self):
         self._test_save_load_optional_components()
+
+
+@slow
+class StableDiffusionXLImg2ImgPipelineIntegrationTests(unittest.TestCase):
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def test_stable_diffusion_xl_img2img_playground(self):
+        torch.manual_seed(0)
+        model_path = "playgroundai/playground-v2.5-1024px-aesthetic"
+
+        sd_pipe = StableDiffusionXLImg2ImgPipeline.from_pretrained(
+            model_path, torch_dtype=torch.float16, variant="fp16", add_watermarker=False
+        )
+
+        sd_pipe.enable_model_cpu_offload()
+        sd_pipe.scheduler = EDMDPMSolverMultistepScheduler.from_config(
+            sd_pipe.scheduler.config, use_karras_sigmas=True
+        )
+        sd_pipe.set_progress_bar_config(disable=None)
+
+        prompt = "a photo of an astronaut riding a horse on mars"
+
+        url = "https://huggingface.co/datasets/patrickvonplaten/images/resolve/main/aa_xl/000000009.png"
+
+        init_image = load_image(url).convert("RGB")
+
+        image = sd_pipe(
+            prompt,
+            num_inference_steps=30,
+            guidance_scale=8.0,
+            image=init_image,
+            height=1024,
+            width=1024,
+            output_type="np",
+        ).images
+
+        image_slice = image[0, -3:, -3:, -1]
+
+        assert image.shape == (1, 1024, 1024, 3)
+
+        expected_slice = np.array([0.3519, 0.3149, 0.3364, 0.3505, 0.3402, 0.3371, 0.3554, 0.3495, 0.3333])
+
+        assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
```
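The new test is gated behind `@slow`, so it only runs when slow tests are opted in (in diffusers this is typically done by setting the `RUN_SLOW=1` environment variable before invoking pytest). The slice assertion it uses is a standard regression pattern; the small sketch below only demonstrates what `image[0, -3:, -3:, -1]` selects, with placeholder reference values rather than the real `expected_slice`.

```python
# Demonstration of the slice check used in the test; the reference values here
# are placeholders, not the real expected_slice from the test.
import numpy as np

image = np.zeros((1, 1024, 1024, 3), dtype=np.float32)  # stand-in for the pipeline's "np" output
image_slice = image[0, -3:, -3:, -1]  # bottom-right 3x3 patch of the last channel
assert image_slice.shape == (3, 3)

expected_slice = np.zeros(9, dtype=np.float32)  # placeholder reference values
assert np.abs(image_slice.flatten() - expected_slice).max() < 1e-2
```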
