@@ -247,33 +247,34 @@ class UNet2DConditionModelTests(ModelTesterMixin, UNetTesterMixin, unittest.Test
     def dummy_input(self):
         batch_size = 4
         num_channels = 4
-        sizes = (32, 32)
+        sizes = (16, 16)
 
         noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
         time_step = torch.tensor([10]).to(torch_device)
-        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 8)).to(torch_device)
 
         return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
 
     @property
     def input_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)
 
     @property
     def output_shape(self):
-        return (4, 32, 32)
+        return (4, 16, 16)
 
     def prepare_init_args_and_inputs_for_common(self):
         init_dict = {
-            "block_out_channels": (32, 64),
+            "block_out_channels": (4, 8),
+            "norm_num_groups": 4,
             "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
             "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
-            "cross_attention_dim": 32,
-            "attention_head_dim": 8,
+            "cross_attention_dim": 8,
+            "attention_head_dim": 2,
             "out_channels": 4,
             "in_channels": 4,
-            "layers_per_block": 2,
-            "sample_size": 32,
+            "layers_per_block": 1,
+            "sample_size": 16,
         }
         inputs_dict = self.dummy_input
         return init_dict, inputs_dict
@@ -337,6 +338,7 @@ def test_gradient_checkpointing(self):
     def test_model_with_attention_head_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -375,7 +377,7 @@ def test_model_with_use_linear_projection(self):
     def test_model_with_cross_attention_dim_tuple(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        init_dict["cross_attention_dim"] = (32, 32)
+        init_dict["cross_attention_dim"] = (8, 8)
 
         model = self.model_class(**init_dict)
         model.to(torch_device)
@@ -443,6 +445,7 @@ def test_model_with_class_embeddings_concat(self):
     def test_model_attention_slicing(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -467,6 +470,7 @@ def test_model_attention_slicing(self):
     def test_model_sliceable_head_dim(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -485,6 +489,7 @@ def check_sliceable_dim_attr(module: torch.nn.Module):
     def test_gradient_checkpointing_is_applied(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model_class_copy = copy.copy(self.model_class)
@@ -561,6 +566,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -571,7 +577,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
         model.set_attn_processor(processor)
         model(**inputs_dict, cross_attention_kwargs={"number": 123}).sample
 
-        assert processor.counter == 12
+        assert processor.counter == 8
         assert processor.is_run
         assert processor.number == 123
 
@@ -587,7 +593,7 @@ def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_ma
     def test_model_xattn_mask(self, mask_dtype):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
-        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16)})
+        model = self.model_class(**{**init_dict, "attention_head_dim": (8, 16), "block_out_channels": (16, 32)})
         model.to(torch_device)
         model.eval()
 
@@ -649,6 +655,7 @@ def test_custom_diffusion_processors(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -675,6 +682,7 @@ def test_custom_diffusion_save_load(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         torch.manual_seed(0)
@@ -714,6 +722,7 @@ def test_custom_diffusion_xformers_on_off(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         torch.manual_seed(0)
@@ -739,6 +748,7 @@ def test_pickle(self):
         # enable deterministic behavior for gradient checkpointing
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -770,6 +780,7 @@ def test_asymmetrical_unet(self):
     def test_ip_adapter(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
@@ -842,6 +853,7 @@ def test_ip_adapter(self):
     def test_ip_adapter_plus(self):
         init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
 
+        init_dict["block_out_channels"] = (16, 32)
         init_dict["attention_head_dim"] = (8, 16)
 
         model = self.model_class(**init_dict)
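
For reference, the slimmed-down `init_dict` introduced in this diff can be exercised on its own. Below is a minimal sketch (not part of the change, assuming `torch` and `diffusers` are installed) that builds the reduced `UNet2DConditionModel` and checks the new input/output shapes:

```python
# Minimal sketch: instantiate the reduced test config from
# prepare_init_args_and_inputs_for_common and run one forward pass.
import torch
from diffusers import UNet2DConditionModel

init_dict = {
    "block_out_channels": (4, 8),
    "norm_num_groups": 4,  # must divide every entry of block_out_channels
    "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
    "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
    "cross_attention_dim": 8,
    "attention_head_dim": 2,
    "out_channels": 4,
    "in_channels": 4,
    "layers_per_block": 1,
    "sample_size": 16,
}

model = UNet2DConditionModel(**init_dict)

sample = torch.randn(4, 4, 16, 16)            # matches the new dummy_input sizes (16, 16)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(4, 4, 8)  # matches the reduced cross_attention_dim

out = model(sample, timestep, encoder_hidden_states).sample
assert out.shape == (4, 4, 16, 16)            # per-sample output_shape is (4, 16, 16)
```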