
Commit 9d55d0a

add tests
1 parent c1b8004 commit 9d55d0a

File tree: 5 files changed (+315, -2 lines)

src/diffusers/pipelines/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -497,6 +497,7 @@
             CogVideoXVideoToVideoPipeline,
         )
         from .cogview3 import CogView3PlusPipeline
+        from .cogview4 import CogView4Pipeline
         from .consisid import ConsisIDPipeline
         from .controlnet import (
             BlipDiffusionControlNetPipeline,
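For context (not part of the commit): this one-line re-export is what makes the new pipeline resolvable from the public diffusers namespace, which the pipeline tests added below rely on. A minimal sanity check, assuming an install that includes this commit:

    # Should import without an ImportError once the re-export above is in place.
    from diffusers import CogView4Pipeline

    print(CogView4Pipeline.__name__)  # CogView4Pipeline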

src/diffusers/pipelines/cogview4/pipeline_cogview4.py
Lines changed: 3 additions & 2 deletions

@@ -16,7 +16,7 @@
 from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
-from transformers import GlmModel
+from transformers import AutoTokenizer, GlmModel

 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
@@ -82,7 +82,7 @@ class CogView4Pipeline(DiffusionPipeline):

     def __init__(
         self,
-        tokenizer: GlmModel,
+        tokenizer: AutoTokenizer,
         text_encoder: GlmModel,
         vae: AutoencoderKL,
         transformer: CogView4Transformer2DModel,
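A hedged illustration of what the corrected annotation implies (not part of the diff): the pipeline's tokenizer slot holds a tokenizer object rather than a GlmModel, and the new pipeline tests load it via AutoTokenizer, roughly like this:

    from transformers import AutoTokenizer

    # Checkpoint name copied from the test file below; its TODO notes that this
    # should switch to THUDM/CogView4 once that repository is released.
    tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)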
@@ -493,6 +493,7 @@ def __call__(
         )
         self.scheduler.set_timesteps(num_inference_steps, image_seq_len, device)
         timesteps = self.scheduler.timesteps
+        self._num_timesteps = len(timesteps)

         # Denoising loop
         transformer_dtype = self.transformer.dtype
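Why the new _num_timesteps assignment matters (a hedged sketch, not part of the diff): step-end callbacks read the number of scheduled steps through pipe.num_timesteps, and the callback test added below relies on it to detect the final denoising step. A callback in that style might look like:

    import torch

    # Hypothetical callback; `pipe` is assumed to be an already-loaded CogView4Pipeline.
    # pipe.num_timesteps is only meaningful because _num_timesteps is now set right
    # after scheduler.set_timesteps(...) in __call__.
    def zero_latents_on_last_step(pipe, step_index, timestep, callback_kwargs):
        if step_index == pipe.num_timesteps - 1:
            # Mirrors callback_inputs_change_tensor in the new pipeline tests.
            callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
        return callback_kwargs

    # Usage sketch:
    # pipe(prompt="...", callback_on_step_end=zero_latents_on_last_step,
    #      callback_on_step_end_tensor_inputs=["latents"])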
Lines changed: 83 additions & 0 deletions

@@ -0,0 +1,83 @@
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import torch

from diffusers import CogView4Transformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..test_modeling_common import ModelTesterMixin


enable_full_determinism()


class CogView3PlusTransformerTests(ModelTesterMixin, unittest.TestCase):
    model_class = CogView4Transformer2DModel
    main_input_name = "hidden_states"
    uses_custom_attn_processor = True

    @property
    def dummy_input(self):
        batch_size = 2
        num_channels = 4
        height = 8
        width = 8
        embedding_dim = 8
        sequence_length = 8

        hidden_states = torch.randn((batch_size, num_channels, height, width)).to(torch_device)
        encoder_hidden_states = torch.randn((batch_size, sequence_length, embedding_dim)).to(torch_device)
        original_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
        target_size = torch.tensor([height * 8, width * 8]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
        crop_coords = torch.tensor([0, 0]).unsqueeze(0).repeat(batch_size, 1).to(torch_device)
        timestep = torch.randint(0, 1000, size=(batch_size,)).to(torch_device)

        return {
            "hidden_states": hidden_states,
            "encoder_hidden_states": encoder_hidden_states,
            "timestep": timestep,
            "original_size": original_size,
            "target_size": target_size,
            "crop_coords": crop_coords,
        }

    @property
    def input_shape(self):
        return (4, 8, 8)

    @property
    def output_shape(self):
        return (4, 8, 8)

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "patch_size": 2,
            "in_channels": 4,
            "num_layers": 2,
            "attention_head_dim": 4,
            "num_attention_heads": 4,
            "out_channels": 4,
            "text_embed_dim": 8,
            "time_embed_dim": 8,
            "condition_dim": 4,
        }
        inputs_dict = self.dummy_input
        return init_dict, inputs_dict

    def test_gradient_checkpointing_is_applied(self):
        expected_set = {"CogView4Transformer2DModel"}
        super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

tests/pipelines/cogview4/__init__.py

Whitespace-only changes.
Lines changed: 228 additions & 0 deletions

@@ -0,0 +1,228 @@
# Copyright 2024 The HuggingFace Team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, GlmConfig, GlmForCausalLM

from diffusers import AutoencoderKL, CogView4DDIMScheduler, CogView4Pipeline, CogView4Transformer2DModel
from diffusers.utils.testing_utils import enable_full_determinism, torch_device

from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
from ..test_pipelines_common import PipelineTesterMixin, to_np


enable_full_determinism()


class CogView4PipelineFastTests(PipelineTesterMixin, unittest.TestCase):
    pipeline_class = CogView4Pipeline
    params = TEXT_TO_IMAGE_PARAMS - {"cross_attention_kwargs"}
    batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
    image_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    image_latents_params = TEXT_TO_IMAGE_IMAGE_PARAMS
    required_optional_params = frozenset(
        [
            "num_inference_steps",
            "generator",
            "latents",
            "return_dict",
            "callback_on_step_end",
            "callback_on_step_end_tensor_inputs",
        ]
    )

    supports_dduf = False
    test_xformers_attention = False
    test_layerwise_casting = True

    def get_dummy_components(self):
        torch.manual_seed(0)
        transformer = CogView4Transformer2DModel(
            patch_size=2,
            in_channels=4,
            num_layers=2,
            attention_head_dim=4,
            num_attention_heads=4,
            out_channels=4,
            text_embed_dim=32,
            time_embed_dim=8,
            condition_dim=4,
        )

        torch.manual_seed(0)
        vae = AutoencoderKL(
            block_out_channels=[32, 64],
            in_channels=3,
            out_channels=3,
            down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
            up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
            latent_channels=4,
            sample_size=128,
        )

        torch.manual_seed(0)
        scheduler = CogView4DDIMScheduler()

        torch.manual_seed(0)
        text_encoder_config = GlmConfig(
            hidden_size=32, intermediate_size=8, num_hidden_layers=2, num_attention_heads=4, head_dim=8
        )
        text_encoder = GlmForCausalLM(text_encoder_config)
        # TODO(aryan): change this to THUDM/CogView4 once released
        tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-4-9b-chat", trust_remote_code=True)

        components = {
            "transformer": transformer,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        return components

    def get_dummy_inputs(self, device, seed=0):
        if str(device).startswith("mps"):
            generator = torch.manual_seed(seed)
        else:
            generator = torch.Generator(device=device).manual_seed(seed)
        inputs = {
            "prompt": "dance monkey",
            "negative_prompt": "",
            "generator": generator,
            "num_inference_steps": 2,
            "guidance_scale": 6.0,
            "height": 16,
            "width": 16,
            "max_sequence_length": 16,
            "output_type": "pt",
        }
        return inputs

    def test_inference(self):
        device = "cpu"

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe.to(device)
        pipe.set_progress_bar_config(disable=None)

        inputs = self.get_dummy_inputs(device)
        image = pipe(**inputs)[0]
        generated_image = image[0]

        self.assertEqual(generated_image.shape, (3, 16, 16))
        expected_image = torch.randn(3, 16, 16)
        max_diff = np.abs(generated_image - expected_image).max()
        self.assertLessEqual(max_diff, 1e10)

    def test_callback_inputs(self):
        sig = inspect.signature(self.pipeline_class.__call__)
        has_callback_tensor_inputs = "callback_on_step_end_tensor_inputs" in sig.parameters
        has_callback_step_end = "callback_on_step_end" in sig.parameters

        if not (has_callback_tensor_inputs and has_callback_step_end):
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        pipe = pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)
        self.assertTrue(
            hasattr(pipe, "_callback_tensor_inputs"),
            f" {self.pipeline_class} should have `_callback_tensor_inputs` that defines a list of tensor variables its callback function can use as inputs",
        )

        def callback_inputs_subset(pipe, i, t, callback_kwargs):
            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        def callback_inputs_all(pipe, i, t, callback_kwargs):
            for tensor_name in pipe._callback_tensor_inputs:
                assert tensor_name in callback_kwargs

            # iterate over callback args
            for tensor_name, tensor_value in callback_kwargs.items():
                # check that we're only passing in allowed tensor inputs
                assert tensor_name in pipe._callback_tensor_inputs

            return callback_kwargs

        inputs = self.get_dummy_inputs(torch_device)

        # Test passing in a subset
        inputs["callback_on_step_end"] = callback_inputs_subset
        inputs["callback_on_step_end_tensor_inputs"] = ["latents"]
        output = pipe(**inputs)[0]

        # Test passing in a everything
        inputs["callback_on_step_end"] = callback_inputs_all
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]

        def callback_inputs_change_tensor(pipe, i, t, callback_kwargs):
            is_last = i == (pipe.num_timesteps - 1)
            if is_last:
                callback_kwargs["latents"] = torch.zeros_like(callback_kwargs["latents"])
            return callback_kwargs

        inputs["callback_on_step_end"] = callback_inputs_change_tensor
        inputs["callback_on_step_end_tensor_inputs"] = pipe._callback_tensor_inputs
        output = pipe(**inputs)[0]
        assert output.abs().sum() < 1e10

    def test_inference_batch_single_identical(self):
        self._test_inference_batch_single_identical(batch_size=3, expected_max_diff=1e-3)

    def test_attention_slicing_forward_pass(
        self, test_max_difference=True, test_mean_pixel_difference=True, expected_max_diff=1e-3
    ):
        if not self.test_attention_slicing:
            return

        components = self.get_dummy_components()
        pipe = self.pipeline_class(**components)
        for component in pipe.components.values():
            if hasattr(component, "set_default_attn_processor"):
                component.set_default_attn_processor()
        pipe.to(torch_device)
        pipe.set_progress_bar_config(disable=None)

        generator_device = "cpu"
        inputs = self.get_dummy_inputs(generator_device)
        output_without_slicing = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=1)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing1 = pipe(**inputs)[0]

        pipe.enable_attention_slicing(slice_size=2)
        inputs = self.get_dummy_inputs(generator_device)
        output_with_slicing2 = pipe(**inputs)[0]

        if test_max_difference:
            max_diff1 = np.abs(to_np(output_with_slicing1) - to_np(output_without_slicing)).max()
            max_diff2 = np.abs(to_np(output_with_slicing2) - to_np(output_without_slicing)).max()
            self.assertLess(
                max(max_diff1, max_diff2),
                expected_max_diff,
                "Attention slicing should not affect the inference results",
            )
