
Commit 714f89d

add t2i test

1 parent 4f8c133

1 file changed: 342 additions, 0 deletions
@@ -0,0 +1,342 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import json
import os
import tempfile
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from diffusers import AutoencoderKLWan, Cosmos2TextToImagePipeline, CosmosTransformer3DModel, EDMEulerScheduler
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np
from .cosmos_guardrail import DummyCosmosSafetyChecker


enable_full_determinism()

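
# `PipelineTesterMixin` routes several shared tests through `from_pretrained`, which
# would otherwise load the real Cosmos Guardrail; that model is too large for fast
# tests, so this wrapper always substitutes the dummy safety checker.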
class Cosmos2TextToImagePipelineWrapper(Cosmos2TextToImagePipeline):
    @staticmethod
    def from_pretrained(*args, **kwargs):
        kwargs["safety_checker"] = DummyCosmosSafetyChecker()
        return Cosmos2TextToImagePipeline.from_pretrained(*args, **kwargs)

class Cosmos2TextToImagePipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = Cosmos2TextToImagePipelineWrapper
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )
    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True
    test_group_offloading = True

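    # Tiny component configs keep this suite runnable on CPU in seconds; the shapes
    # only need to be internally consistent, not match any released checkpoint.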
    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = CosmosTransformer3DModel(
            in_channels=16,
            out_channels=16,
            num_attention_heads=2,
            attention_head_dim=16,
            num_layers=2,
            mlp_ratio=2,
            text_embed_dim=32,
            adaln_lora_dim=4,
            max_size=(4, 32, 32),
            patch_size=(1, 2, 2),
            rope_scale=(2.0, 1.0, 1.0),
            concat_padding_mask=True,
            extra_pos_embed_type="learnable",
        )

        torch.manual_seed(0)
        vae = AutoencoderKLWan(
            base_dim=3,
            z_dim=16,
            dim_mult=[1, 1, 1, 1],
            num_res_blocks=1,
            temperal_downsample=[False, True, True],
        )

        torch.manual_seed(0)
        scheduler = EDMEulerScheduler(
            sigma_min=0.002,
            sigma_max=80,
            sigma_data=0.5,
            sigma_schedule="karras",
            num_train_timesteps=1000,
            prediction_type="epsilon",
            rho=7.0,
            final_sigmas_type="sigma_min",
            use_flow_sigmas=True,
        )
        text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
        tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
            # We cannot run the Cosmos Guardrail for fast tests due to the large model size
            "safety_checker": DummyCosmosSafetyChecker(),
        }
        return components

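    # Minimal two-step, 32x32 generation; `output_type="pt"` returns tensors so the
    # tests below can run numeric comparisons directly.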
    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)

        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "bad quality",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 3.0,
            "height": 32,
            "width": 32,
            "max_sequence_length": 16,
            "output_type": "pt",
        }

        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs).images
        generated_image = image[0]

        self.assertEqual(generated_image.shape, (3, 32, 32))
        # Placeholder reference: with a 1e10 tolerance this only smoke-tests that the
        # pipeline runs and produces values of the right shape (NaN/inf would fail).
        expected_image = torch.randn(3, 32, 32)
        max_diff = np.abs(generated_image - expected_image).max()
        self.assertLessEqual(max_diff, 1e10)

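    # Re-implemented from the shared mixin: checks that `callback_on_step_end` receives
    # exactly the tensors named in `callback_on_step_end_tensor_inputs`, and that a
    # callback can overwrite the latents on the final denoising step.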
    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f"{self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )

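    # Renders the same 128x128 request with and without VAE tiling and requires the
    # two outputs to stay within `expected_diff_max` of each other.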
    def test_vae_tiling(self, expected_diff_max: float = 0.2):
        generator_device = "cpu"
        components = self.get_dummy_components()

        pipe = self.pipeline_class(**components)
        pipe.to("cpu")
        pipe.set_progress_bar_config(disable=None)

        # Without tiling
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_without_tiling = pipe(**inputs)[0]

        # With tiling
        pipe.vae.enable_tiling(
            tile_sample_min_height=96,
            tile_sample_min_width=96,
            tile_sample_stride_height=64,
            tile_sample_stride_width=64,
        )
        inputs = self.get_dummy_inputs(generator_device)
        inputs["height"] = inputs["width"] = 128
        output_with_tiling = pipe(**inputs)[0]

        self.assertLess(
            (to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
            expected_diff_max,
            "VAE tiling should not affect the inference results",
        )

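    # `safety_checker` is taken out of the optional components for the duration of the
    # shared test so the save/load round-trip keeps the dummy checker instead of
    # falling back to the real Cosmos Guardrail.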
    def test_save_load_optional_components(self, expected_max_difference=1e-4):
        self.pipeline_class._optional_components.remove("safety_checker")
        super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
        self.pipeline_class._optional_components.append("safety_checker")

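    # Saves the pipeline with an fp16 variant and checks that every model subfolder
    # contains variant-named weight files; the dummy safety checker is excluded from
    # the check.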
    def test_serialization_with_variants(self):
        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        model_components = [
            component_name
            for component_name, component in pipe.components.items()
            if isinstance(component, torch.nn.Module)
        ]
        model_components.remove("safety_checker")
        variant = "fp16"

        with tempfile.TemporaryDirectory() as tmpdir:
            pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)

            with open(f"{tmpdir}/model_index.json", "r") as f:
                config = json.load(f)

            for subfolder in os.listdir(tmpdir):
                if not os.path.isfile(subfolder) and subfolder in model_components:
                    folder_path = os.path.join(tmpdir, subfolder)
                    is_folder = os.path.isdir(folder_path) and subfolder in config
                    assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))

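    # Loads with a per-component `torch_dtype` dict: the first component should come
    # back in bfloat16 and everything else in the "default" float16.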
    def test_torch_dtype_dict(self):
        components = self.get_dummy_components()
        if not components:
            self.skipTest("No dummy components defined.")

        pipe = self.pipeline_class(**components)

        specified_key = next(iter(components.keys()))

        with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
            pipe.save_pretrained(tmpdirname, safe_serialization=False)
            torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
            loaded_pipe = self.pipeline_class.from_pretrained(
                tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
            )

        for name, component in loaded_pipe.components.items():
            if name == "safety_checker":
                continue
            if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
                expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
                self.assertEqual(
                    component.dtype,
                    expected_dtype,
                    f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
                )

    @unittest.skip(
        "The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
        "a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
        "too large and slow to run on CI."
    )
    def test_encode_prompt_works_in_isolation(self):
        pass
