|  | 
| 19 | 19 | import torch | 
| 20 | 20 | from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer | 
| 21 | 21 | 
 | 
| 22 |  | -from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler | 
|  | 22 | +from diffusers import AutoencoderDC, SanaSprintImg2ImgPipeline, SanaTransformer2DModel, SCMScheduler | 
| 23 | 23 | from diffusers.utils.testing_utils import ( | 
| 24 | 24 |     enable_full_determinism, | 
| 25 | 25 |     torch_device, | 
| 26 | 26 | ) | 
| 27 | 27 | 
 | 
| 28 |  | -from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS | 
|  | 28 | +from ..pipeline_params import ( | 
|  | 29 | +    IMAGE_TO_IMAGE_IMAGE_PARAMS, | 
|  | 30 | +    TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, | 
|  | 31 | +    TEXT_GUIDED_IMAGE_VARIATION_PARAMS, | 
|  | 32 | +) | 
| 29 | 33 | from ..test_pipelines_common import PipelineTesterMixin, to_np | 
| 30 | 34 | 
 | 
| 31 | 35 | 
 | 
| 32 | 36 | enable_full_determinism() | 
| 33 | 37 | 
 | 
| 34 | 38 | 
 | 
| 35 |  | -class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase): | 
| 36 |  | -    pipeline_class = SanaSprintPipeline | 
|  | 39 | +class SanaSprintImg2ImgPipelineFastTests(PipelineTesterMixin, unittest.TestCase): | 
|  | 40 | +    pipeline_class = SanaSprintImg2ImgPipeline | 
| 37 |  | -    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"} | 
|  | 41 | +    params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"} | 
| 38 |  | -    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"} | 
| 39 |  | -    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"} | 
| 40 |  | -    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS | 
|  | 42 | +    batch_params = TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS - {"negative_prompt"} | 
|  | 43 | +    image_params = IMAGE_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"} | 
|  | 44 | +    image_latents_params = IMAGE_TO_IMAGE_IMAGE_PARAMS | 
| 41 | 45 |     required_optional_params = frozenset( | 
| 42 | 46 |         [ | 
| 43 | 47 |             "num_inference_steps", | 
| @@ -126,12 +130,15 @@ def get_dummy_components(self): | 
| 126 | 130 |         return components | 
| 127 | 131 | 
 | 
| 128 | 132 |     def get_dummy_inputs(self, device, seed=0): | 
|  | 133 | +        image = floats_tensor((1, 3, 32, 32), rng=random.Random(seed)).to(device) | 
| 129 | 134 |         if str(device).startswith("mps"): | 
| 130 | 135 |             generator = torch.manual_seed(seed) | 
| 131 | 136 |         else: | 
| 132 | 137 |             generator = torch.Generator(device=device).manual_seed(seed) | 
| 133 | 138 |         inputs = { | 
| 134 | 139 |             "prompt": "", | 
|  | 140 | +            "image": image, | 
|  | 141 | +            "strength": 0.5, | 
| 135 | 142 |             "generator": generator, | 
| 136 | 143 |             "num_inference_steps": 2, | 
| 137 | 144 |             "guidance_scale": 6.0, | 
|  | 
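For reference, a minimal sketch (not part of the diff) of how the img2img inputs added above would be exercised against the test's dummy pipeline. It assumes the `get_dummy_components()` helper defined earlier in this file, that `random` and `floats_tensor` are imported at the top of the module, and an `output_type` argument that is not visible in the truncated hunk.

```python
import random

import torch

from diffusers.utils.testing_utils import floats_tensor, torch_device


def run_img2img_smoke_test(test_case):
    # Build the tiny pipeline from the test's dummy components
    # (`get_dummy_components` is defined earlier in this file).
    pipe = test_case.pipeline_class(**test_case.get_dummy_components())
    pipe.to(torch_device)

    # Mirror the dummy inputs added in this diff: a random 32x32 init image
    # plus a denoising strength for the img2img path.
    image = floats_tensor((1, 3, 32, 32), rng=random.Random(0)).to(torch_device)
    generator = torch.Generator(device=torch_device).manual_seed(0)

    output = pipe(
        prompt="",
        image=image,          # new img2img input
        strength=0.5,         # new img2img input
        generator=generator,
        num_inference_steps=2,
        guidance_scale=6.0,
        output_type="np",     # assumption: not shown in the truncated hunk
    )
    return output.images
```

Compared with the text-to-image test this file was adapted from, the only call-site difference is the `image`/`strength` pair; the remaining arguments match the original dummy inputs.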