
Commit 0e2c037

committed
initial commit - add img2img test
1 parent 7db64a1 commit 0e2c037

File tree

1 file changed: +302 -0 lines changed

Lines changed: 302 additions & 0 deletions
@@ -0,0 +1,302 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import numpy as np
import torch
from transformers import Gemma2Config, Gemma2Model, GemmaTokenizer

from diffusers import AutoencoderDC, SanaSprintPipeline, SanaTransformer2DModel, SCMScheduler
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    torch_device,
)

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class SanaSprintPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = SanaSprintPipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs", "negative_prompt", "negative_prompt_embeds"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS - {"negative_prompt"}
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

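    # Builds minimal, randomly initialized versions of every pipeline component
    # (transformer, VAE, scheduler, Gemma2 text encoder, tokenizer) so the fast
    # tests below can run quickly on CPU.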
    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = SanaTransformer2DModel(
            patch_size=1,
            in_channels=4,
            out_channels=4,
            num_layers=1,
            num_attention_heads=2,
            attention_head_dim=4,
            num_cross_attention_heads=2,
            cross_attention_head_dim=4,
            cross_attention_dim=8,
            caption_channels=8,
            sample_size=32,
            qk_norm="rms_norm_across_heads",
            guidance_embeds=True,
        )

        torch.manual_seed(0)
        vae = AutoencoderDC(
            in_channels=3,
            latent_channels=4,
            attention_head_dim=2,
            encoder_block_types=(
                "ResBlock",
                "EfficientViTBlock",
            ),
            decoder_block_types=(
                "ResBlock",
                "EfficientViTBlock",
            ),
            encoder_block_out_channels=(8, 8),
            decoder_block_out_channels=(8, 8),
            encoder_qkv_multiscales=((), (5,)),
            decoder_qkv_multiscales=((), (5,)),
            encoder_layers_per_block=(1, 1),
            decoder_layers_per_block=[1, 1],
            downsample_block_type="conv",
            upsample_block_type="interpolate",
            decoder_norm_types="rms_norm",
            decoder_act_fns="silu",
            scaling_factor=0.41407,
        )

        torch.manual_seed(0)
        scheduler = SCMScheduler()

        torch.manual_seed(0)
        text_encoder_config = Gemma2Config(
            head_dim=16,
            hidden_size=8,
            initializer_range=0.02,
            intermediate_size=64,
            max_position_embeddings=8192,
            model_type="gemma2",
            num_attention_heads=2,
            num_hidden_layers=1,
            num_key_value_heads=2,
            vocab_size=8,
            attn_implementation="eager",
        )
        text_encoder = Gemma2Model(text_encoder_config)
        tokenizer = GemmaTokenizer.from_pretrained("hf-internal-testing/dummy-gemma")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

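    # The dummy inputs use an empty prompt: the tiny text-encoder vocabulary cannot
    # embed arbitrary text (see the skipped batch tests at the end of the class).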
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
            "complex_human_instruction": None,
        }
        return inputs

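    # Smoke test: runs the full pipeline on CPU and checks the output shape. The
    # comparison against a random reference with a 1e10 tolerance is a placeholder,
    # so the check is effectively shape-only.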
    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs)[0]
        generated_image = image[0]

        self.assertEqual(generated_image.shape, (3, 32, 32))
        expected_image = torch.randn(3, 32, 32)
        max_diff = np.abs(generated_image - expected_image).max()
        self.assertLessEqual(max_diff, 1e10)

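    # Verifies the step-end callback contract: only tensors registered in
    # `_callback_tensor_inputs` may be passed to the callback, and the callback
    # can replace those tensors (here, zeroing the latents on the last step).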
    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

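    # Attention slicing (slice sizes 1 and 2) must produce the same images as the
    # unsliced forward pass, within `expected_max_diff`.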
    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

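    # Decoding a 128x128 request with VAE tiling enabled (96-pixel tiles, 64-pixel
    # strides) should stay within `expected_diff_max` of the untiled output.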
    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

    # TODO(aryan): Create a dummy gemma model with smol vocab size
    @unittest.skip(
        "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to an embedding lookup error. This test uses a long prompt that causes the error."
    )
    def test_inference_batch_consistent(self):
        pass

    @unittest.skip(
        "A very small vocab size is used for fast tests. So, any kind of prompt other than the empty default used in other tests will lead to an embedding lookup error. This test uses a long prompt that causes the error."
    )
    def test_inference_batch_single_identical(self):
        pass

    def test_float16_inference(self):
        # Requires higher tolerance as the model seems very sensitive to dtype
        super().test_float16_inference(expected_max_diff=0.08)

0 commit comments