Commit 3ae0ee8

Authored by a-r-r-o-w, sayakpaul, and DN6
[tests] speed up animatediff tests (#8846)
* speed up animatediff tests
* fix pia test_ip_adapter_single
* fix tests/pipelines/pia/test_pia.py::PIAPipelineFastTests::test_dict_tuple_outputs_equivalent
* update
* fix ip adapter tests
* skip test_from_pipe_consistent_config tests
* fix prompt_embeds test
* update test_from_pipe_consistent_config tests
* fix expected_slice values
* remove temporal_norm_num_groups from UpBlockMotion

Co-authored-by: Sayak Paul <[email protected]>
Co-authored-by: Dhruv Nair <[email protected]>
1 parent 5fbb4d3 commit 3ae0ee8

4 files changed: +205 −97 lines changed
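Most of the speed-up appears to come from shrinking the dummy components the fast tests build: block_out_channels drops from (32, 64) to (8, 8), sample_size from 32 to 8, and cross_attention_dim from 32 to 8, as the test diffs below show. A rough, hedged sketch (not part of the PR) of how much smaller that makes the dummy UNet:

import torch
from diffusers import UNet2DConditionModel

def dummy_unet_param_count(block_out_channels, sample_size, cross_attention_dim):
    # Same kwargs the fast tests use for their dummy UNet; only the sizes vary.
    torch.manual_seed(0)
    unet = UNet2DConditionModel(
        block_out_channels=block_out_channels,
        layers_per_block=2,
        sample_size=sample_size,
        in_channels=4,
        out_channels=4,
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
        cross_attention_dim=cross_attention_dim,
        norm_num_groups=2,
    )
    return sum(p.numel() for p in unet.parameters())

before = dummy_unet_param_count((32, 64), 32, 32)
after = dummy_unet_param_count((8, 8), 8, 8)
print(before, after)  # the (8, 8) configuration is a small fraction of the original size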

src/diffusers/models/unets/unet_3d_blocks.py

Lines changed: 1 addition & 2 deletions
@@ -1532,7 +1532,6 @@ def __init__(
         resnet_pre_norm: bool = True,
         output_scale_factor: float = 1.0,
         add_upsample: bool = True,
-        temporal_norm_num_groups: int = 32,
         temporal_cross_attention_dim: Optional[int] = None,
         temporal_num_attention_heads: int = 8,
         temporal_max_seq_length: int = 32,
@@ -1574,7 +1573,7 @@ def __init__(
                     num_attention_heads=temporal_num_attention_heads,
                     in_channels=out_channels,
                     num_layers=temporal_transformer_layers_per_block[i],
-                    norm_num_groups=temporal_norm_num_groups,
+                    norm_num_groups=resnet_groups,
                     cross_attention_dim=temporal_cross_attention_dim,
                     attention_bias=False,
                     activation_fn="geglu",

tests/pipelines/animatediff/test_animatediff.py

Lines changed: 71 additions & 35 deletions
@@ -11,6 +11,7 @@
     AutoencoderKL,
     DDIMScheduler,
     MotionAdapter,
+    StableDiffusionPipeline,
     UNet2DConditionModel,
     UNetMotionModel,
 )
@@ -51,16 +52,19 @@ class AnimateDiffPipelineFastTests(
     )
 
     def get_dummy_components(self):
+        cross_attention_dim = 8
+        block_out_channels = (8, 8)
+
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=block_out_channels,
             layers_per_block=2,
-            sample_size=32,
+            sample_size=8,
             in_channels=4,
             out_channels=4,
             down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
-            cross_attention_dim=32,
+            cross_attention_dim=cross_attention_dim,
             norm_num_groups=2,
         )
         scheduler = DDIMScheduler(
@@ -71,18 +75,19 @@ def get_dummy_components(self):
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=block_out_channels,
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
         )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
-            hidden_size=32,
+            hidden_size=cross_attention_dim,
             intermediate_size=37,
             layer_norm_eps=1e-05,
             num_attention_heads=4,
@@ -92,8 +97,9 @@ def get_dummy_components(self):
         )
         text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        torch.manual_seed(0)
         motion_adapter = MotionAdapter(
-            block_out_channels=(32, 64),
+            block_out_channels=block_out_channels,
             motion_layers_per_block=2,
             motion_norm_num_groups=2,
             motion_num_attention_heads=4,
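One small change in get_dummy_components above (mirrored in the video-to-video tests below) is the extra torch.manual_seed(0) before the MotionAdapter is built. Presumably this decouples the adapter's random init from how much randomness the components constructed before it consumed, so the hard-coded expected_slice values further down stay stable; a standalone illustration of the effect:

import torch

# Draw some "component" randomness, then an "adapter" init, without re-seeding:
torch.manual_seed(0)
_ = torch.randn(3)           # randomness consumed by earlier components
adapter_a = torch.randn(4)

# Change how much the earlier components consume and the adapter init drifts:
torch.manual_seed(0)
_ = torch.randn(5)
adapter_b = torch.randn(4)
print(torch.equal(adapter_a, adapter_b))  # False

# Re-seeding immediately before the adapter pins its init regardless of what ran before:
torch.manual_seed(0)
adapter_c = torch.randn(4)
torch.manual_seed(0)
adapter_d = torch.randn(4)
print(torch.equal(adapter_c, adapter_d))  # True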
@@ -126,6 +132,36 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def test_from_pipe_consistent_config(self):
+        assert self.original_pipeline_class == StableDiffusionPipeline
+        original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
+        original_kwargs = {"requires_safety_checker": False}
+
+        # create original_pipeline_class(sd)
+        pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
+
+        # original_pipeline_class(sd) -> pipeline_class
+        pipe_components = self.get_dummy_components()
+        pipe_additional_components = {}
+        for name, component in pipe_components.items():
+            if name not in pipe_original.components:
+                pipe_additional_components[name] = component
+
+        pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
+
+        # pipeline_class -> original_pipeline_class(sd)
+        original_pipe_additional_components = {}
+        for name, component in pipe_original.components.items():
+            if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
+                original_pipe_additional_components[name] = component
+
+        pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
+
+        # compare the config
+        original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
+        original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
+        assert original_config_2 == original_config
+
     def test_motion_unet_loading(self):
         components = self.get_dummy_components()
         pipe = AnimateDiffPipeline(**components)
@@ -141,41 +177,41 @@ def test_ip_adapter_single(self):
         if torch_device == "cpu":
             expected_pipe_slice = np.array(
                 [
-                    0.5541,
-                    0.5802,
-                    0.5074,
-                    0.4583,
-                    0.4729,
-                    0.5374,
-                    0.4051,
-                    0.4495,
-                    0.4480,
-                    0.5292,
-                    0.6322,
-                    0.6265,
-                    0.5455,
-                    0.4771,
-                    0.5795,
-                    0.5845,
-                    0.4172,
-                    0.6066,
-                    0.6535,
-                    0.4113,
-                    0.6833,
-                    0.5736,
-                    0.3589,
-                    0.5730,
-                    0.4205,
-                    0.3786,
-                    0.5323,
+                    0.5216,
+                    0.5620,
+                    0.4927,
+                    0.5082,
+                    0.4786,
+                    0.5932,
+                    0.5125,
+                    0.4514,
+                    0.5315,
+                    0.4694,
+                    0.3276,
+                    0.4863,
+                    0.3920,
+                    0.3684,
+                    0.5745,
+                    0.4499,
+                    0.5081,
+                    0.5414,
+                    0.6014,
+                    0.5062,
+                    0.3630,
+                    0.5296,
+                    0.6018,
+                    0.5098,
+                    0.4948,
+                    0.5101,
+                    0.5620,
                 ]
             )
         return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
 
     def test_dict_tuple_outputs_equivalent(self):
         expected_slice = None
         if torch_device == "cpu":
-            expected_slice = np.array([0.4051, 0.4495, 0.4480, 0.5845, 0.4172, 0.6066, 0.4205, 0.3786, 0.5323])
+            expected_slice = np.array([0.5125, 0.4514, 0.5315, 0.4499, 0.5081, 0.5414, 0.4948, 0.5101, 0.5620])
         return super().test_dict_tuple_outputs_equivalent(expected_slice=expected_slice)
 
     def test_inference_batch_single_identical(
@@ -279,7 +315,7 @@ def test_prompt_embeds(self):
 
         inputs = self.get_dummy_inputs(torch_device)
         inputs.pop("prompt")
-        inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
+        inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
         pipe(**inputs)
 
     def test_free_init(self):
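The test_prompt_embeds fix above replaces a hard-coded embedding width of 32 with the text encoder's configured hidden size, which the new dummy components set to 8 and which must match the UNet's cross_attention_dim. A hypothetical helper (not part of the test suite) that captures the same coupling:

import torch

def make_dummy_prompt_embeds(pipe, batch_size=1, seq_len=4):
    # The last dimension has to match what the UNet's cross-attention layers expect,
    # which in these fast tests equals the text encoder's hidden size (now 8).
    hidden_size = pipe.text_encoder.config.hidden_size
    return torch.randn((batch_size, seq_len, hidden_size), device=pipe.device)

The same change is applied to the video-to-video tests below.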

tests/pipelines/animatediff/test_animatediff_video2video.py

Lines changed: 62 additions & 25 deletions
@@ -11,6 +11,7 @@
     AutoencoderKL,
     DDIMScheduler,
     MotionAdapter,
+    StableDiffusionPipeline,
     UNet2DConditionModel,
     UNetMotionModel,
 )
@@ -46,16 +47,19 @@ class AnimateDiffVideoToVideoPipelineFastTests(
     )
 
     def get_dummy_components(self):
+        cross_attention_dim = 8
+        block_out_channels = (8, 8)
+
         torch.manual_seed(0)
         unet = UNet2DConditionModel(
-            block_out_channels=(32, 64),
+            block_out_channels=block_out_channels,
             layers_per_block=2,
-            sample_size=32,
+            sample_size=8,
             in_channels=4,
             out_channels=4,
             down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
             up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
-            cross_attention_dim=32,
+            cross_attention_dim=cross_attention_dim,
             norm_num_groups=2,
         )
         scheduler = DDIMScheduler(
@@ -66,18 +70,19 @@ def get_dummy_components(self):
         )
         torch.manual_seed(0)
         vae = AutoencoderKL(
-            block_out_channels=[32, 64],
+            block_out_channels=block_out_channels,
             in_channels=3,
             out_channels=3,
             down_block_types=["DownEncoderBlock2D", "DownEncoderBlock2D"],
             up_block_types=["UpDecoderBlock2D", "UpDecoderBlock2D"],
             latent_channels=4,
+            norm_num_groups=2,
        )
         torch.manual_seed(0)
         text_encoder_config = CLIPTextConfig(
             bos_token_id=0,
             eos_token_id=2,
-            hidden_size=32,
+            hidden_size=cross_attention_dim,
             intermediate_size=37,
             layer_norm_eps=1e-05,
             num_attention_heads=4,
@@ -87,8 +92,9 @@ def get_dummy_components(self):
         )
         text_encoder = CLIPTextModel(text_encoder_config)
         tokenizer = CLIPTokenizer.from_pretrained("hf-internal-testing/tiny-random-clip")
+        torch.manual_seed(0)
         motion_adapter = MotionAdapter(
-            block_out_channels=(32, 64),
+            block_out_channels=block_out_channels,
             motion_layers_per_block=2,
             motion_norm_num_groups=2,
             motion_num_attention_heads=4,
@@ -127,6 +133,36 @@ def get_dummy_inputs(self, device, seed=0):
         }
         return inputs
 
+    def test_from_pipe_consistent_config(self):
+        assert self.original_pipeline_class == StableDiffusionPipeline
+        original_repo = "hf-internal-testing/tinier-stable-diffusion-pipe"
+        original_kwargs = {"requires_safety_checker": False}
+
+        # create original_pipeline_class(sd)
+        pipe_original = self.original_pipeline_class.from_pretrained(original_repo, **original_kwargs)
+
+        # original_pipeline_class(sd) -> pipeline_class
+        pipe_components = self.get_dummy_components()
+        pipe_additional_components = {}
+        for name, component in pipe_components.items():
+            if name not in pipe_original.components:
+                pipe_additional_components[name] = component
+
+        pipe = self.pipeline_class.from_pipe(pipe_original, **pipe_additional_components)
+
+        # pipeline_class -> original_pipeline_class(sd)
+        original_pipe_additional_components = {}
+        for name, component in pipe_original.components.items():
+            if name not in pipe.components or not isinstance(component, pipe.components[name].__class__):
+                original_pipe_additional_components[name] = component
+
+        pipe_original_2 = self.original_pipeline_class.from_pipe(pipe, **original_pipe_additional_components)
+
+        # compare the config
+        original_config = {k: v for k, v in pipe_original.config.items() if not k.startswith("_")}
+        original_config_2 = {k: v for k, v in pipe_original_2.config.items() if not k.startswith("_")}
+        assert original_config_2 == original_config
+
     def test_motion_unet_loading(self):
         components = self.get_dummy_components()
         pipe = AnimateDiffVideoToVideoPipeline(**components)
@@ -143,24 +179,24 @@ def test_ip_adapter_single(self):
         if torch_device == "cpu":
             expected_pipe_slice = np.array(
                 [
-                    0.4947,
-                    0.4780,
-                    0.4340,
-                    0.4666,
-                    0.4028,
-                    0.4645,
-                    0.4915,
-                    0.4101,
-                    0.4308,
-                    0.4581,
-                    0.3582,
-                    0.4953,
-                    0.4466,
-                    0.5348,
-                    0.5863,
-                    0.5299,
+                    0.5569,
+                    0.6250,
+                    0.4145,
+                    0.5613,
+                    0.5563,
                     0.5213,
-                    0.5017,
+                    0.5092,
+                    0.4950,
+                    0.4950,
+                    0.5685,
+                    0.3858,
+                    0.4864,
+                    0.6458,
+                    0.4312,
+                    0.5518,
+                    0.5608,
+                    0.4418,
+                    0.5378,
                 ]
             )
         return super().test_ip_adapter_single(expected_pipe_slice=expected_pipe_slice)
@@ -266,7 +302,7 @@ def test_prompt_embeds(self):
 
         inputs = self.get_dummy_inputs(torch_device)
         inputs.pop("prompt")
-        inputs["prompt_embeds"] = torch.randn((1, 4, 32), device=torch_device)
+        inputs["prompt_embeds"] = torch.randn((1, 4, pipe.text_encoder.config.hidden_size), device=torch_device)
         pipe(**inputs)
 
     def test_latent_inputs(self):
@@ -276,7 +312,8 @@ def test_latent_inputs(self):
         pipe.to(torch_device)
 
         inputs = self.get_dummy_inputs(torch_device)
-        inputs["latents"] = torch.randn((1, 4, 1, 32, 32), device=torch_device)
+        sample_size = pipe.unet.config.sample_size
+        inputs["latents"] = torch.randn((1, 4, 1, sample_size, sample_size), device=torch_device)
         inputs.pop("video")
         pipe(**inputs)
 
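Likewise, test_latent_inputs now derives the latent spatial size from the UNet's sample_size instead of assuming 32. A hypothetical helper (not part of the test suite) that mirrors the new shape logic:

import torch

def make_dummy_latents(pipe, batch_size=1, num_frames=1):
    # Shape is (batch, latent_channels, frames, height, width); in these fast tests
    # the latent height/width follow the dummy UNet's sample_size, now 8.
    sample_size = pipe.unet.config.sample_size
    latent_channels = pipe.vae.config.latent_channels
    return torch.randn(
        (batch_size, latent_channels, num_frames, sample_size, sample_size),
        device=pipe.device,
    )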
