From 0e0d986533a8c5191b92ef894df1a302d9e1dab1 Mon Sep 17 00:00:00 2001 From: Dhruv Nair Date: Thu, 19 Jun 2025 11:31:49 +0200 Subject: [PATCH] update --- tests/pipelines/sana/test_sana_controlnet.py | 3 ++- tests/pipelines/sana/test_sana_sprint_img2img.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/pipelines/sana/test_sana_controlnet.py b/tests/pipelines/sana/test_sana_controlnet.py index 803f608ba655..9b5c9e439e29 100644 --- a/tests/pipelines/sana/test_sana_controlnet.py +++ b/tests/pipelines/sana/test_sana_controlnet.py @@ -30,6 +30,7 @@ enable_full_determinism, torch_device, ) +from diffusers.utils.torch_utils import randn_tensor from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS from ..test_pipelines_common import PipelineTesterMixin, to_np @@ -151,7 +152,7 @@ def get_dummy_inputs(self, device, seed=0): else: generator = torch.Generator(device=device).manual_seed(seed) - control_image = torch.randn(1, 3, 32, 32, generator=generator) + control_image = randn_tensor((1, 3, 32, 32), generator=generator, device=device) inputs = { "prompt": "", "negative_prompt": "", diff --git a/tests/pipelines/sana/test_sana_sprint_img2img.py b/tests/pipelines/sana/test_sana_sprint_img2img.py index 1179346d4c10..c0e4bf8e356f 100644 --- a/tests/pipelines/sana/test_sana_sprint_img2img.py +++ b/tests/pipelines/sana/test_sana_sprint_img2img.py @@ -24,6 +24,7 @@ enable_full_determinism, torch_device, ) +from diffusers.utils.torch_utils import randn_tensor from ..pipeline_params import ( IMAGE_TO_IMAGE_IMAGE_PARAMS, @@ -137,7 +138,7 @@ def get_dummy_inputs(self, device, seed=0): generator = torch.manual_seed(seed) else: generator = torch.Generator(device=device).manual_seed(seed) - image = torch.randn(1, 3, 32, 32, generator=generator) + image = randn_tensor((1, 3, 32, 32), generator=generator, device=device) inputs = { "prompt": "", "image": image,