Skip to content

Commit 479d9d2

Browse files
committed
small changes
1 parent 5297450 commit 479d9d2

File tree

1 file changed

+2
-5
lines changed

1 file changed

+2
-5
lines changed

tests/pipelines/sana/test_sana_sprint_img2img.py

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
# limitations under the License.
1414

1515
import inspect
16-
import random
1716
import unittest
1817

1918
import numpy as np
@@ -23,7 +22,6 @@
2322
from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler
2423
from diffusers.utils.testing_utils import (
2524
enable_full_determinism,
26-
floats_tensor,
2725
torch_device,
2826
)
2927

@@ -41,12 +39,11 @@
4139
class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
4240
pipeline_class = SanaSprintImg2ImgPipeline
4341
params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {
44-
"cross_attention_kwargs",
4542
"negative_prompt",
4643
"negative_prompt_embeds",
4744
}
4845
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"negative_prompt"}
49-
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"}
46+
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
5047
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
5148
required_optional_params = frozenset(
5249
[
@@ -136,7 +133,7 @@ def get_dummy_components(self):
136133
return components
137134

138135
def get_dummy_inputs(self, device, seed=0):
139-
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
136+
image = torch.randn(1, 3, 32, 32, generator=generator)
140137
if str(device).startswith("mps"):
141138
generator = torch.manual_seed(seed)
142139
else:

0 commit comments

Comments
 (0)