Skip to content

Commit 06e852d

Browse files
committed
add test for v2w pipeline
1 parent 714f89d commit 06e852d

File tree

2 files changed

+363
-5
lines changed

2 files changed

+363
-5
lines changed

src/diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -650,8 +650,8 @@ def __call__(
650650
unconditioning_latents = conditioning_latents
651651

652652
padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)
653-
sigma_conditioning = torch.full((batch_size,), sigma_conditioning, dtype=torch.float32, device=device)
654-
sigma_conditioning_t = self.scheduler.precondition_noise(sigma_conditioning)
653+
sigma_conditioning = torch.tensor(sigma_conditioning, dtype=torch.float32, device=device)
654+
t_conditioning = self.scheduler.precondition_noise(sigma_conditioning)
655655

656656
# 6. Denoising loop
657657
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
@@ -663,13 +663,15 @@ def __call__(
663663
continue
664664

665665
self._current_timestep = t
666-
timestep = t.view(1, 1, 1, 1, 1).repeat(latents.size(0), 1, latents.size(2), 1, 1) # [B, 1, T, 1, 1]
666+
timestep = t.view(1, 1, 1, 1, 1).expand(
667+
latents.size(0), -1, latents.size(2), -1, -1
668+
) # [B, 1, T, 1, 1]
667669
current_sigma = self.scheduler.sigmas[i]
668670

669671
cond_latent = self.scheduler.scale_model_input(latents, t)
670672
cond_latent = cond_indicator * conditioning_latents + (1 - cond_indicator) * cond_latent
671673
cond_latent = cond_latent.to(transformer_dtype)
672-
cond_timestep = cond_indicator * sigma_conditioning_t + (1 - cond_indicator) * timestep
674+
cond_timestep = cond_indicator * t_conditioning + (1 - cond_indicator) * timestep
673675
cond_timestep = cond_timestep.to(transformer_dtype)
674676

675677
noise_pred = self.transformer(
@@ -688,7 +690,7 @@ def __call__(
688690
uncond_latent = self.scheduler.scale_model_input(latents, t)
689691
uncond_latent = uncond_indicator * unconditioning_latents + (1 - uncond_indicator) * uncond_latent
690692
uncond_latent = uncond_latent.to(transformer_dtype)
691-
uncond_timestep = uncond_indicator * sigma_conditioning_t + (1 - uncond_indicator) * timestep
693+
uncond_timestep = uncond_indicator * t_conditioning + (1 - uncond_indicator) * timestep
692694
uncond_timestep = uncond_timestep.to(transformer_dtype)
693695

694696
noise_pred_uncond = self.transformer(
Lines changed: 356 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
# Copyright 2024 The HuggingFace Team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import inspect
16+
import json
17+
import os
18+
import tempfile
19+
import unittest
20+
21+
import numpy as np
22+
import PIL.Image
23+
import torch
24+
from transformers import AutoTokenizer, T5EncoderModel
25+
26+
from diffusers import AutoencoderKLWan, Cosmos2VideoToWorldPipeline, CosmosTransformer3DModel, EDMEulerScheduler
27+
from diffusers.utils.testing_utils import enable_full_determinism, torch_device
28+
29+
from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
30+
from ..test_pipelines_common import PipelineTesterMixin, to_np
31+
from .cosmos_guardrail import DummyCosmosSafetyChecker
32+
33+
34+
enable_full_determinism()
35+
36+
37+
class Cosmos2VideoToWorldPipelineWrapper(Cosmos2VideoToWorldPipeline):
38+
@staticmethod
39+
def from_pretrained(*args, **kwargs):
40+
kwargs["safety_checker"] = DummyCosmosSafetyChecker()
41+
return Cosmos2VideoToWorldPipeline.from_pretrained(*args, **kwargs)
42+
43+
44+
class Cosmos2VideoToWorldPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
45+
pipeline_class = Cosmos2VideoToWorldPipelineWrapper
46+
params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
47+
batch_params = TEXT_TO_IMAGE_BATCH_PARAMS.union({"image", "video"})
48+
image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
49+
image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
50+
required_optional_params = frozenset(
51+
[
52+
"num_inference_steps",
53+
"generator",
54+
"latents",
55+
"return_dict",
56+
"callback_on_step_end",
57+
"callback_on_step_end_tensor_inputs",
58+
]
59+
)
60+
supports_dduf = False
61+
test_xformers_attention = False
62+
test_layerwise_casting = True
63+
test_group_offloading = True
64+
65+
def get_dummy_components(self):
66+
torch.manual_seed(0)
67+
transformer = CosmosTransformer3DModel(
68+
in_channels=16 + 1,
69+
out_channels=16,
70+
num_attention_heads=2,
71+
attention_head_dim=16,
72+
num_layers=2,
73+
mlp_ratio=2,
74+
text_embed_dim=32,
75+
adaln_lora_dim=4,
76+
max_size=(4, 32, 32),
77+
patch_size=(1, 2, 2),
78+
rope_scale=(2.0, 1.0, 1.0),
79+
concat_padding_mask=True,
80+
extra_pos_embed_type="learnable",
81+
)
82+
83+
torch.manual_seed(0)
84+
vae = AutoencoderKLWan(
85+
base_dim=3,
86+
z_dim=16,
87+
dim_mult=[1, 1, 1, 1],
88+
num_res_blocks=1,
89+
temperal_downsample=[False, True, True],
90+
)
91+
92+
torch.manual_seed(0)
93+
scheduler = EDMEulerScheduler(
94+
sigma_min=0.002,
95+
sigma_max=80,
96+
sigma_data=0.5,
97+
sigma_schedule="karras",
98+
num_train_timesteps=1000,
99+
prediction_type="epsilon",
100+
rho=7.0,
101+
final_sigmas_type="sigma_min",
102+
use_flow_sigmas=True,
103+
)
104+
text_encoder = T5EncoderModel.from_pretrained("hf-internal-testing/tiny-random-t5")
105+
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/tiny-random-t5")
106+
107+
components = {
108+
"transformer": transformer,
109+
"vae": vae,
110+
"scheduler": scheduler,
111+
"text_encoder": text_encoder,
112+
"tokenizer": tokenizer,
113+
# We cannot run the Cosmos Guardrail for fast tests due to the large model size
114+
"safety_checker": DummyCosmosSafetyChecker(),
115+
}
116+
return components
117+
118+
def get_dummy_inputs(self, device, seed=0):
119+
if str(device).startswith("mps"):
120+
generator = torch.manual_seed(seed)
121+
else:
122+
generator = torch.Generator(device=device).manual_seed(seed)
123+
124+
image_height = 32
125+
image_width = 32
126+
image = PIL.Image.new("RGB", (image_width, image_height))
127+
128+
inputs = {
129+
"image": image,
130+
"prompt": "dance monkey",
131+
"negative_prompt": "bad quality",
132+
"generator": generator,
133+
"num_inference_steps": 2,
134+
"guidance_scale": 3.0,
135+
"height": image_height,
136+
"width": image_width,
137+
"num_frames": 9,
138+
"max_sequence_length": 16,
139+
"output_type": "pt",
140+
}
141+
142+
return inputs
143+
144+
def test_inference(self):
145+
device = "cpu"
146+
147+
components = self.get_dummy_components()
148+
pipe = self.pipeline_class(**components)
149+
pipe.to(device)
150+
pipe.set_progress_bar_config(disable=None)
151+
152+
inputs = self.get_dummy_inputs(device)
153+
video = pipe(**inputs).frames
154+
generated_video = video[0]
155+
156+
self.assertEqual(generated_video.shape, (9, 3, 32, 32))
157+
expected_video = torch.randn(9, 3, 32, 32)
158+
max_diff = np.abs(generated_video - expected_video).max()
159+
self.assertLessEqual(max_diff, 1e10)
160+
161+
def test_components_function(self):
162+
init_components = self.get_dummy_components()
163+
init_components = {k: v for k, v in init_components.items() if not isinstance(v, (str, int, float))}
164+
pipe = self.pipeline_class(**init_components)
165+
self.assertTrue(hasattr(pipe, "components"))
166+
self.assertTrue(set(pipe.components.keys()) == set(init_components.keys()))
167+
168+
def test_callback_inputs(self):
169+
sig = inspect.signature(self.pipeline_class.__call__)
170+
has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
171+
has_callback_step_end = "callback_on_step_end" in sig.parameters
172+
173+
if not (has_callback_tensor_inputs and has_callback_step_end):
174+
return
175+
176+
components = self.get_dummy_components()
177+
pipe = self.pipeline_class(**components)
178+
pipe = pipe.to(torch_device)
179+
pipe.set_progress_bar_config(disable=None)
180+
self.assertTrue(
181+
hasattr(pipe, "_callback_tensor_inputs"),
182+
f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
183+
)
184+
185+
def callback_inputs_subset(pipe, i, t, callback_kwargs):
186+
# iterate over callback args
187+
for tensor_name, tensor_value in callback_kwargs.items():
188+
# check that we're only passing in allowed tensor inputs
189+
assert tensor_name in pipe._callback_tensor_inputs
190+
191+
return callback_kwargs
192+
193+
def callback_inputs_all(pipe, i, t, callback_kwargs):
194+
for tensor_name in pipe._callback_tensor_inputs:
195+
assert tensor_name in callback_kwargs
196+
197+
# iterate over callback args
198+
for tensor_name, tensor_value in callback_kwargs.items():
199+
# check that we're only passing in allowed tensor inputs
200+
assert tensor_name in pipe._callback_tensor_inputs
201+
202+
return callback_kwargs
203+
204+
inputs = self.get_dummy_inputs(torch_device)
205+
206+
# Test passing in a subset
207+
inputs["callback_on_step_end"] = callback_inputs_subset
208+
inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
209+
output = pipe(**inputs)[0]
210+
211+
# Test passing in a everything
212+
inputs["callback_on_step_end"] = callback_inputs_all
213+
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
214+
output = pipe(**inputs)[0]
215+
216+
def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
217+
is_last = i == (pipe.num_timesteps - 1)
218+
if is_last:
219+
callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
220+
return callback_kwargs
221+
222+
inputs["callback_on_step_end"] = callback_inputs_change_tensor
223+
inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
224+
output = pipe(**inputs)[0]
225+
assert output.abs().sum() < 1e10
226+
227+
def test_inference_batch_single_identical(self):
228+
self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-2)
229+
230+
def test_attention_slicing_forward_pass(
231+
self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
232+
):
233+
if not self.test_attention_slicing:
234+
return
235+
236+
components = self.get_dummy_components()
237+
pipe = self.pipeline_class(**components)
238+
for component in pipe.components.values():
239+
if hasattr(component, "set_default_attn_processor"):
240+
component.set_default_attn_processor()
241+
pipe.to(torch_device)
242+
pipe.set_progress_bar_config(disable=None)
243+
244+
generator_device = "cpu"
245+
inputs = self.get_dummy_inputs(generator_device)
246+
output_without_slicing = pipe(**inputs)[0]
247+
248+
pipe.enable_attention_slicing(slice_size=1)
249+
inputs = self.get_dummy_inputs(generator_device)
250+
output_with_slicing1 = pipe(**inputs)[0]
251+
252+
pipe.enable_attention_slicing(slice_size=2)
253+
inputs = self.get_dummy_inputs(generator_device)
254+
output_with_slicing2 = pipe(**inputs)[0]
255+
256+
if test_max_difference:
257+
max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
258+
max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
259+
self.assertLess(
260+
max(max_diff1, max_diff2),
261+
expected_max_diff,
262+
"Attention slicing should not affect the inference results",
263+
)
264+
265+
def test_vae_tiling(self, expected_diff_max: float = 0.2):
266+
generator_device = "cpu"
267+
components = self.get_dummy_components()
268+
269+
pipe = self.pipeline_class(**components)
270+
pipe.to("cpu")
271+
pipe.set_progress_bar_config(disable=None)
272+
273+
# Without tiling
274+
inputs = self.get_dummy_inputs(generator_device)
275+
inputs["height"] = inputs["width"] = 128
276+
output_without_tiling = pipe(**inputs)[0]
277+
278+
# With tiling
279+
pipe.vae.enable_tiling(
280+
tile_sample_min_height=96,
281+
tile_sample_min_width=96,
282+
tile_sample_stride_height=64,
283+
tile_sample_stride_width=64,
284+
)
285+
inputs = self.get_dummy_inputs(generator_device)
286+
inputs["height"] = inputs["width"] = 128
287+
output_with_tiling = pipe(**inputs)[0]
288+
289+
self.assertLess(
290+
(to_np(output_without_tiling) - to_np(output_with_tiling)).max(),
291+
expected_diff_max,
292+
"VAE tiling should not affect the inference results",
293+
)
294+
295+
def test_save_load_optional_components(self, expected_max_difference=1e-4):
296+
self.pipeline_class._optional_components.remove("safety_checker")
297+
super().test_save_load_optional_components(expected_max_difference=expected_max_difference)
298+
self.pipeline_class._optional_components.append("safety_checker")
299+
300+
def test_serialization_with_variants(self):
301+
components = self.get_dummy_components()
302+
pipe = self.pipeline_class(**components)
303+
model_components = [
304+
component_name
305+
for component_name, component in pipe.components.items()
306+
if isinstance(component, torch.nn.Module)
307+
]
308+
model_components.remove("safety_checker")
309+
variant = "fp16"
310+
311+
with tempfile.TemporaryDirectory() as tmpdir:
312+
pipe.save_pretrained(tmpdir, variant=variant, safe_serialization=False)
313+
314+
with open(f"{tmpdir}/model_index.json", "r") as f:
315+
config = json.load(f)
316+
317+
for subfolder in os.listdir(tmpdir):
318+
if not os.path.isfile(subfolder) and subfolder in model_components:
319+
folder_path = os.path.join(tmpdir, subfolder)
320+
is_folder = os.path.isdir(folder_path) and subfolder in config
321+
assert is_folder and any(p.split(".")[1].startswith(variant) for p in os.listdir(folder_path))
322+
323+
def test_torch_dtype_dict(self):
324+
components = self.get_dummy_components()
325+
if not components:
326+
self.skipTest("No dummy components defined.")
327+
328+
pipe = self.pipeline_class(**components)
329+
330+
specified_key = next(iter(components.keys()))
331+
332+
with tempfile.TemporaryDirectory(ignore_cleanup_errors=True) as tmpdirname:
333+
pipe.save_pretrained(tmpdirname, safe_serialization=False)
334+
torch_dtype_dict = {specified_key: torch.bfloat16, "default": torch.float16}
335+
loaded_pipe = self.pipeline_class.from_pretrained(
336+
tmpdirname, safety_checker=DummyCosmosSafetyChecker(), torch_dtype=torch_dtype_dict
337+
)
338+
339+
for name, component in loaded_pipe.components.items():
340+
if name == "safety_checker":
341+
continue
342+
if isinstance(component, torch.nn.Module) and hasattr(component, "dtype"):
343+
expected_dtype = torch_dtype_dict.get(name, torch_dtype_dict.get("default", torch.float32))
344+
self.assertEqual(
345+
component.dtype,
346+
expected_dtype,
347+
f"Component '{name}' has dtype {component.dtype} but expected {expected_dtype}",
348+
)
349+
350+
@unittest.skip(
351+
"The pipeline should not be runnable without a safety checker. The test creates a pipeline without passing in "
352+
"a safety checker, which makes the pipeline default to the actual Cosmos Guardrail. The Cosmos Guardrail is "
353+
"too large and slow to run on CI."
354+
)
355+
def test_encode_prompt_works_in_isolation(self):
356+
pass

0 commit comments

Comments
 (0)