Commit 1c60e09

[Tests] reduce block sizes of UNet and VAE tests (#7560)
* reduce block sizes for unet1d.
* reduce blocks for unet_2d.
* reduce block size for unet_motion
* increase channels.
* correctly increase channels.
* reduce number of layers in unet2dconditionmodel tests.
* reduce block sizes for unet2dconditionmodel tests
* reduce block sizes for unet3dconditionmodel.
* fix: test_feed_forward_chunking
* fix: test_forward_with_norm_groups
* skip spatiotemporal tests on MPS.
* reduce block size in AutoencoderKL.
* reduce block sizes for vqmodel.
* further reduce block size.
* make style.
* Empty-Commit
* reduce sizes for ConsistencyDecoderVAETests
* further reduction.
* further block reductions in AutoencoderKL and AsymmetricAutoencoderKL.
* massively reduce the block size in unet2dconditionmodel.
* reduce sizes for unet3d
* fix tests in unet3d.
* reduce blocks further in motion unet.
* fix: output shape
* add attention_head_dim to the test configuration.
* remove unexpected keyword arg
* up a bit.
* groups.
* up again
* fix
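Most of these reductions lower norm_num_groups together with block_out_channels because of the group-normalization constraint used throughout the UNet and VAE blocks: every normalized layer's channel count must be divisible by the number of groups, otherwise torch.nn.GroupNorm raises an error at construction time. A minimal sketch of that constraint in plain PyTorch, with values chosen purely for illustration:

import torch

# GroupNorm requires num_channels to be divisible by num_groups.
ok = torch.nn.GroupNorm(num_groups=2, num_channels=4)  # 4 % 2 == 0, builds fine
print(ok(torch.randn(1, 4, 8, 8)).shape)  # torch.Size([1, 4, 8, 8])

try:
    torch.nn.GroupNorm(num_groups=32, num_channels=4)  # 4 % 32 != 0
except ValueError as err:
    print(err)  # num_channels must be divisible by num_groups

This is also why the tests that keep the old norm_num_groups=32 (for example test_forward_with_norm_groups and test_feed_forward_chunking) locally restore the larger block_out_channels=(32, 64).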
1 parent 71f49a5 commit 1c60e09

8 files changed, +61 −41 lines changed

tests/models/autoencoders/test_models_vae.py

Lines changed: 6 additions & 6 deletions
@@ -53,8 +53,8 @@
 
 
 def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     init_dict = {
         "block_out_channels": block_out_channels,
         "in_channels": 3,
@@ -68,8 +68,8 @@ def get_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
 
 
 def get_asym_autoencoder_kl_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     init_dict = {
         "in_channels": 3,
         "out_channels": 3,
@@ -102,8 +102,8 @@ def get_autoencoder_tiny_config(block_out_channels=None):
 
 
 def get_consistency_vae_config(block_out_channels=None, norm_num_groups=None):
-    block_out_channels = block_out_channels or [32, 64]
-    norm_num_groups = norm_num_groups or 32
+    block_out_channels = block_out_channels or [2, 4]
+    norm_num_groups = norm_num_groups or 2
     return {
         "encoder_block_out_channels": block_out_channels,
         "encoder_in_channels": 3,

tests/models/autoencoders/test_models_vq.py

Lines changed: 2 additions & 1 deletion
@@ -54,7 +54,8 @@ def output_shape(self):
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": [32, 64],
+            "block_out_channels": [8, 16],
+            "norm_num_groups": 8,
             "in_channels": 3,
             "out_channels": 3,
             "down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],

tests/models/unets/test_models_unet_1d.py

Lines changed: 1 addition & 1 deletion
@@ -77,7 +77,7 @@ def test_output(self):
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64, 128, 256),
+            "block_out_channels": (8, 8, 16, 16),
             "in_channels": 14,
             "out_channels": 14,
             "time_embedding_type": "positional",

tests/models/unets/test_models_unet_2d.py

Lines changed: 3 additions & 3 deletions
@@ -63,7 +63,8 @@ def output_shape(self):
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 2,
             "down_block_types": ("DownBlock2D", "AttnDownBlock2D"),
             "up_block_types": ("AttnUpBlock2D", "UpBlock2D"),
             "attention_head_dim": 3,
@@ -78,9 +79,8 @@ def prepare_init_args_and_inputs_for_common(self):
     def test_mid_block_attn_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        init_dict["norm_num_groups"] = 16
         init_dict["add_attention"] = True
-        init_dict["attn_norm_num_groups"] = 8
+        init_dict["attn_norm_num_groups"] = 4
 
         model = self.model_class(**init_dict)
         model.to(torch_device)

tests/models/unets/test_models_unet_2d_condition.py

Lines changed: 24 additions & 12 deletions
@@ -247,33 +247,34 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        sizes = (32, 32)
+        sizes = (16, 16)
 
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
 
         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
 
     @property
     def input_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)
 
     @property
     def output_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
             "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
-            "layers_per_block": 2,
-            "sample_size": 32,
+            "layers_per_block": 1,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -337,6 +338,7 @@ def test_gradient_checkpointing(self):
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -375,7 +377,7 @@ def test_model_with_use_linear_projection(self):
     def test_model_with_cross_attention_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        init_dict["cross_attention_dim"] = (32, 32)
+        init_dict["cross_attention_dim"] = (8, 8)
 
         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -443,6 +445,7 @@ def test_model_with_class_embeddings_concat(self):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -467,6 +470,7 @@ def test_model_attention_slicing(self):
     def test_model_sliceable_head_dim(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -485,6 +489,7 @@ def check_sliceable_dim_attr(module: torch.nn.Module):
     def test_gradient_checkpointing_is_applied(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model_class_copy = copy.copy(self.model_class)
@@ -561,6 +566,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -571,7 +577,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         model.set_attn_processor(processor)
         model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
 
-        assert processor.counter == 12
+        assert processor.counter == 8
         assert processor.is_run
         assert processor.number == 123
 
@@ -587,7 +593,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
     def test_model_xattn_mask(self, mask_dtype):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
         model.to(torch_device)
         model.eval()
 
@@ -649,6 +655,7 @@ def test_custom_diffusion_processors(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -675,6 +682,7 @@ def test_custom_diffusion_save_load(self):
         # enable deterministic behavior for gradient checkpointing
        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         torch.manual_seed(0)
@@ -714,6 +722,7 @@ def test_custom_diffusion_xformers_on_off(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         torch.manual_seed(0)
@@ -739,6 +748,7 @@ def test_pickle(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -770,6 +780,7 @@ def test_asymmetrical_unet(self):
     def test_ip_adapter(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -842,6 +853,7 @@ def test_ip_adapter(self):
     def test_ip_adapter_plus(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)

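For reference, this is roughly what the reduced UNet2DConditionModel test configuration above builds and returns. It is a sketch against the public diffusers API rather than code taken from the test file, with shapes mirroring dummy_input and output_shape:

import torch
from diffusers import UNet2DConditionModel

# Tiny configuration matching the new init_dict in the test above.
model = UNet2DConditionModel(
    block_out_channels=(4, 8),
    norm_num_groups=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=8,
    attention_head_dim=2,
    in_channels=4,
    out_channels=4,
    layers_per_block=1,
    sample_size=16,
)

sample = torch.randn(4, 4, 16, 16)            # (batch, channels, height, width)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(4, 4, 8)  # (batch, sequence, cross_attention_dim)

out = model(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
print(out.shape)  # torch.Size([4, 4, 16, 16]), matching output_shape

A model this small keeps the forward pass and parameter count negligible, which is presumably the point of shrinking these configurations: the shared ModelTesterMixin tests exercise the same code paths at a fraction of the cost.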
tests/models/unets/test_models_unet_3d_condition.py

Lines changed: 12 additions & 9 deletions
@@ -41,36 +41,37 @@ def dummy_input(self):
         batch_size = 4
         num_channels = 4
         num_frames = 4
-        sizes = (32, 32)
+        sizes = (16, 16)
 
         noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
 
         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
 
     @property
     def input_shape(self):
-        return (4, 4, 32, 32)
+        return (4, 4, 16, 16)
 
     @property
     def output_shape(self):
-        return (4, 4, 32, 32)
+        return (4, 4, 16, 16)
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": (
                 "CrossAttnDownBlock3D",
                 "DownBlock3D",
             ),
             "up_block_types": ("UpBlock3D", "CrossAttnUpBlock3D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
             "layers_per_block": 1,
-            "sample_size": 32,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -93,7 +94,7 @@ def test_xformers_enable_works(self):
     # Overriding to set `norm_num_groups` needs to be different for this model.
     def test_forward_with_norm_groups(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
-
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32
 
         model = self.model_class(**init_dict)
@@ -140,6 +141,7 @@ def test_determinism(self):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = 8
 
         model = self.model_class(**init_dict)
@@ -163,6 +165,7 @@ def test_model_attention_slicing(self):
 
     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32
 
         model = self.model_class(**init_dict)

tests/models/unets/test_models_unet_motion.py

Lines changed: 11 additions & 9 deletions
@@ -46,34 +46,35 @@ class UNetMotionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase)
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        num_frames = 8
-        sizes = (32, 32)
+        num_frames = 4
+        sizes = (16, 16)
 
         noise = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 16)).to(torch_device)
 
         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
 
     @property
     def input_shape(self):
-        return (4, 8, 32, 32)
+        return (4, 4, 16, 16)
 
     @property
     def output_shape(self):
-        return (4, 8, 32, 32)
+        return (4, 4, 16, 16)
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (16, 32),
+            "norm_num_groups": 16,
             "down_block_types": ("CrossAttnDownBlockMotion", "DownBlockMotion"),
             "up_block_types": ("UpBlockMotion", "CrossAttnUpBlockMotion"),
-            "cross_attention_dim": 32,
-            "num_attention_heads": 4,
+            "cross_attention_dim": 16,
+            "num_attention_heads": 2,
             "out_channels": 4,
             "in_channels": 4,
             "layers_per_block": 1,
-            "sample_size": 32,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -194,6 +195,7 @@ def _set_gradient_checkpointing_new(self, module, value=False):
 
     def test_feed_forward_chunking(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        init_dict["block_out_channels"] = (32, 64)
         init_dict["norm_num_groups"] = 32
 
         model = self.model_class(**init_dict)

tests/models/unets/test_models_unet_spatiotemporal.py

Lines changed: 2 additions & 0 deletions
@@ -24,6 +24,7 @@
 from diffusers.utils.testing_utils import (
     enable_full_determinism,
     floats_tensor,
+    skip_mps,
     torch_all_close,
     torch_device,
 )
@@ -36,6 +37,7 @@
 enable_full_determinism()
 
 
+@skip_mps
 class UNetSpatioTemporalConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.TestCase):
     model_class = UNetSpatioTemporalConditionModel
     main_input_name = "sample"

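The skip_mps decorator added above comes from diffusers.utils.testing_utils and skips the decorated test case (or, as here, the entire test class) when the test device is Apple's MPS backend, matching the "skip spatiotemporal tests on MPS" item in the commit message. A rough, hypothetical re-creation of such a guard, for illustration only and not the library's actual implementation:

import unittest

import torch

# Illustrative stand-in for a "skip on MPS" decorator; the real helper in
# diffusers.utils.testing_utils keys off the device selected for the test run.
_on_mps = torch.backends.mps.is_available()
skip_mps = unittest.skipIf(_on_mps, "Test not supported on the MPS backend")

@skip_mps
class ExampleSpatioTemporalTests(unittest.TestCase):
    def test_placeholder(self):
        self.assertTrue(True)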