Skip to content

Commit 4dad325

Browse files
committed
initial commit - add img2img test
1 parent 0e2c037 commit 4dad325

File tree

1 file changed

+14
-7
lines changed

1 file changed

+14
-7
lines changed

tests/pipelines/sana/test_sana_sprint_img2img.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,25 +19,29 @@
1919
import torch
2020
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer
2121

22-
from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
22+
from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler
2323
from diffusers.utils.testing_utils import (
2424
enable_full_determinism,
2525
torch_device,
2626
)
2727

28-
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
28+
rom ..pipeline_params import (
29+
IMAGE_TO_IMAGE_IMAGE_PARAMS,
30+
TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
31+
TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
32+
)
2933
from ..test_pipelines_common import PipelineTesterMixin, to_np
3034

3135

3236
enable_full_determinism()
3337

3438

35-
class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
36-
pipeline_class = SanaSprintPipeline
39+
class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
40+
pipeline_class = SanaSprintImg2ImgPipeline
3741
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"}
38-
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
39-
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"}
40-
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
42+
batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"negative_prompt"}
43+
image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"}
44+
image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS
4145
required_optional_params = frozenset(
4246
[
4347
"num_inference_steps",
@@ -126,12 +130,15 @@ def get_dummy_components(self):
126130
return components
127131

128132
def get_dummy_inputs(self, device, seed=0):
133+
image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device)
129134
if str(device).startswith("mps"):
130135
generator = torch.manual_seed(seed)
131136
else:
132137
generator = torch.Generator(device=device).manual_seed(seed)
133138
inputs = {
134139
"prompt": "",
140+
"image": image,
141+
"strength": 0.5,
135142
"generator": generator,
136143
"num_inference_steps": 2,
137144
"guidance_scale": 6.0,

0 commit comments

Comments
 (0)