@@ -247,33 +247,34 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        sizes = (32, 32)
+        sizes = (16, 16)

         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)

         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}

     @property
     def input_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)

     @property
     def output_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)

     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
             "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
-            "layers_per_block": 2,
-            "sample_size": 32,
+            "layers_per_block": 1,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
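Note (illustrative, not part of the diff): a minimal sketch of what the shrunken common config now builds, assuming the public diffusers UNet2DConditionModel API and plain torch tensors standing in for the floats_tensor/torch_device test helpers.

import torch
from diffusers import UNet2DConditionModel

# Same values as the new prepare_init_args_and_inputs_for_common()
model = UNet2DConditionModel(
    block_out_channels=(4, 8),
    norm_num_groups=4,        # must divide every entry of block_out_channels
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=8,    # matches the new encoder_hidden_states width
    attention_head_dim=2,
    out_channels=4,
    in_channels=4,
    layers_per_block=1,
    sample_size=16,
)

# Mirrors the new dummy_input, on CPU
sample = torch.randn(4, 4, 16, 16)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(4, 4, 8)

with torch.no_grad():
    out = model(sample, timestep, encoder_hidden_states).sample

assert out.shape == (4, 4, 16, 16)  # agrees with the new output_shape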
@@ -337,6 +338,7 @@ def test_gradient_checkpointing(self):
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
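Note (illustrative, not part of the diff): the (16, 32) override is needed because attention_head_dim effectively sets the per-block head count in this UNet (a long-standing naming quirk), so each block's channel count has to be a positive multiple of it; the tiny new default of (4, 8) is not. A hedged sanity check:

for channels, heads in zip((16, 32), (8, 16)):
    assert channels % heads == 0 and channels >= heads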
@@ -375,7 +377,7 @@ def test_model_with_use_linear_projection(self):
     def test_model_with_cross_attention_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

-        init_dict["cross_attention_dim"] = (32, 32)
+        init_dict["cross_attention_dim"] = (8, 8)

         model = self.model_class(**init_dict)
         model.to(torch_device)
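Note (illustrative, not part of the diff): a per-block cross_attention_dim tuple still has to match the width of the dummy encoder_hidden_states, which the first hunk shrank from 32 to 8:

encoder_hidden_dim = 8  # last dim of the new dummy encoder_hidden_states
assert all(dim == encoder_hidden_dim for dim in (8, 8))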
@@ -443,6 +445,7 @@ def test_model_with_class_embeddings_concat(self):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -467,6 +470,7 @@ def test_model_attention_slicing(self):
     def test_model_sliceable_head_dim(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -485,6 +489,7 @@ def check_sliceable_dim_attr(module: torch.nn.Module):
     def test_gradient_checkpointing_is_applied(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model_class_copy = copy.copy(self.model_class)
@@ -561,6 +566,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -571,7 +577,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         model.set_attn_processor(processor)
         model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample

-        assert processor.counter == 12
+        assert processor.counter == 8
         assert processor.is_run
         assert processor.number == 123
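Note (illustrative, not part of the diff): a rough accounting of the new expected count, assuming the custom processor fires once per attention layer and each Transformer block holds one self- plus one cross-attention layer:

layers_per_block = 1                    # from the shrunken common config
down_calls = layers_per_block * 2       # CrossAttnDownBlock2D: 1 transformer block
mid_calls = 1 * 2                       # default cross-attention mid block
up_calls = (layers_per_block + 1) * 2   # CrossAttnUpBlock2D: layers_per_block + 1 blocks
assert down_calls + mid_calls + up_calls == 8  # was 4 + 2 + 6 = 12 with layers_per_block=2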
@@ -587,7 +593,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
     def test_model_xattn_mask(self, mask_dtype):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

-        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
         model.to(torch_device)
         model.eval()
@@ -649,6 +655,7 @@ def test_custom_diffusion_processors(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -675,6 +682,7 @@ def test_custom_diffusion_save_load(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         torch.manual_seed(0)
@@ -714,6 +722,7 @@ def test_custom_diffusion_xformers_on_off(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         torch.manual_seed(0)
@@ -739,6 +748,7 @@ def test_pickle(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -770,6 +780,7 @@ def test_asymmetrical_unet(self):
     def test_ip_adapter(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)
@@ -842,6 +853,7 @@ def test_ip_adapter(self):
     def test_ip_adapter_plus(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)

         model = self.model_class(**init_dict)