22import  unittest 
33from  typing  import  Callable , Union 
44
5+ from  diffusers .utils .dummy_pt_objects  import  ModularPipeline , ModularPipelineBlocks 
56import  numpy  as  np 
67import  torch 
78
89import  diffusers 
9- from  diffusers  import  (
10-     DiffusionPipeline ,
11- )
1210from  diffusers .utils  import  logging 
1311from  diffusers .utils .testing_utils  import  (
1412    backend_empty_cache ,
@@ -42,7 +40,7 @@ class ModularPipelineTesterMixin:
4240    # Canonical parameters that are passed to `__call__` regardless 
4341    # of the type of pipeline. They are always optional and have common 
4442    # sense default values. 
45-     required_optional_params  =  frozenset (
43+     optional_params  =  frozenset (
4644        [
4745            "num_inference_steps" ,
4846            "num_images_per_prompt" ,
@@ -51,7 +49,7 @@ class ModularPipelineTesterMixin:
5149        ]
5250    )
5351    # this is modular specific: generator needs to be a intermediate input because it's mutable 
54-     required_intermediate_params  =  frozenset (
52+     intermediate_params  =  frozenset (
5553        [
5654            "generator" ,
5755        ]
@@ -63,7 +61,7 @@ def get_generator(self, seed):
6361        return  generator 
6462
6563    @property  
66-     def  pipeline_class (self ) ->  Union [Callable , DiffusionPipeline ]:
64+     def  pipeline_class (self ) ->  Union [Callable , ModularPipeline ]:
6765        raise  NotImplementedError (
6866            "You need to set the attribute `pipeline_class = ClassNameOfPipeline` in the child test class. " 
6967            "See existing pipeline tests for reference." 
@@ -76,7 +74,7 @@ def repo(self) -> str:
7674        )
7775
7876    @property  
79-     def  pipeline_blocks_class (self ) ->  Union [Callable , DiffusionPipeline ]:
77+     def  pipeline_blocks_class (self ) ->  Union [Callable , ModularPipelineBlocks ]:
8078        raise  NotImplementedError (
8179            "You need to set the attribute `pipeline_blocks_class = ClassNameOfPipelineBlocks` in the child test class. " 
8280            "See existing pipeline tests for reference." 
@@ -139,49 +137,21 @@ def tearDown(self):
139137
140138    def  test_pipeline_call_signature (self ):
141139        pipe  =  self .get_pipeline ()
142-         parameters  =  pipe .blocks .input_names 
143-         optional_parameters  =  pipe .default_call_parameters 
140+         input_parameters  =  pipe .blocks .input_names 
144141        intermediate_parameters  =  pipe .blocks .intermediate_input_names 
142+         optional_parameters  =  pipe .default_call_parameters 
145143
146-         remaining_required_parameters  =  set ()
147- 
148-         for  param  in  self .params :
149-             if  param  not  in   parameters :
150-                 remaining_required_parameters .add (param )
151- 
152-         self .assertTrue (
153-             len (remaining_required_parameters ) ==  0 ,
154-             f"Required parameters not present: { remaining_required_parameters }  " ,
155-         )
156- 
157-         remaining_required_intermediate_parameters  =  set ()
158- 
159-         for  param  in  self .required_intermediate_params :
160-             if  param  not  in   intermediate_parameters :
161-                 remaining_required_intermediate_parameters .add (param )
162- 
163-         self .assertTrue (
164-             len (remaining_required_intermediate_parameters ) ==  0 ,
165-             f"Required intermediate parameters not present: { remaining_required_intermediate_parameters }  " ,
166-         )
167- 
168-         remaining_required_optional_parameters  =  set ()
169- 
170-         for  param  in  self .required_optional_params :
171-             if  param  not  in   optional_parameters :
172-                 remaining_required_optional_parameters .add (param )
173- 
174-         self .assertTrue (
175-             len (remaining_required_optional_parameters ) ==  0 ,
176-             f"Required optional parameters not present: { remaining_required_optional_parameters }  " ,
177-         )
144+         def  _check_for_parameters (parameters , expected_parameters , param_type ):
145+             remaining_parameters  =  set (param  for  param  in  parameters  if  param  not  in   expected_parameters )
146+             assert  (
147+                 len (remaining_parameters ) ==  0 
148+             ), f"Required { param_type }   parameters not present: { remaining_parameters }  " 
178149
179-     def  test_inference_batch_consistent (self , batch_sizes = [2 ]):
180-         self ._test_inference_batch_consistent (batch_sizes = batch_sizes )
150+         _check_for_parameters (self .params , input_parameters , "input" )
151+         _check_for_parameters (self .intermediate_params , intermediate_parameters , "intermediate" )
152+         _check_for_parameters (self .optional_params , optional_parameters , "optional" )
181153
182-     def  _test_inference_batch_consistent (
183-         self , batch_sizes = [2 ], additional_params_copy_to_batched_inputs = ["num_inference_steps" ], batch_generator = True 
184-     ):
154+     def  test_inference_batch_consistent (self , batch_sizes = [2 ], batch_generator = True ):
185155        pipe  =  self .get_pipeline ()
186156        pipe .to (torch_device )
187157        pipe .set_progress_bar_config (disable = None )
@@ -203,16 +173,7 @@ def _test_inference_batch_consistent(
203173                    continue 
204174
205175                value  =  inputs [name ]
206-                 if  name  ==  "prompt" :
207-                     len_prompt  =  len (value )
208-                     # make unequal batch sizes 
209-                     batched_input [name ] =  [value [: len_prompt  //  i ] for  i  in  range (1 , batch_size  +  1 )]
210- 
211-                     # make last batch super long 
212-                     batched_input [name ][- 1 ] =  100  *  "very long" 
213- 
214-                 else :
215-                     batched_input [name ] =  batch_size  *  [value ]
176+                 batched_input [name ] =  batch_size  *  [value ]
216177
217178            if  batch_generator  and  "generator"  in  inputs :
218179                batched_input ["generator" ] =  [self .get_generator (i ) for  i  in  range (batch_size )]
@@ -225,21 +186,18 @@ def _test_inference_batch_consistent(
225186        logger .setLevel (level = diffusers .logging .WARNING )
226187        for  batch_size , batched_input  in  zip (batch_sizes , batched_inputs ):
227188            output  =  pipe (** batched_input , output = "images" )
228-             assert  len (output ) ==  batch_size 
229- 
230-     def  test_inference_batch_single_identical (self , batch_size = 3 , expected_max_diff = 1e-4 ):
231-         self ._test_inference_batch_single_identical (batch_size = batch_size , expected_max_diff = expected_max_diff )
189+             assert  len (output ) ==  batch_size , "Output is different from expected batch size" 
232190
233-     def  _test_inference_batch_single_identical (
191+     def  test_batch_inference_identical_to_single (
234192        self ,
235193        batch_size = 2 ,
236194        expected_max_diff = 1e-4 ,
237-         additional_params_copy_to_batched_inputs = ["num_inference_steps" ],
238195    ):
239196        pipe  =  self .get_pipeline ()
240197        pipe .to (torch_device )
241198        pipe .set_progress_bar_config (disable = None )
242199        inputs  =  self .get_dummy_inputs (torch_device )
200+ 
243201        # Reset generator in case it is has been used in self.get_dummy_inputs 
244202        inputs ["generator" ] =  self .get_generator (0 )
245203
@@ -255,40 +213,30 @@ def _test_inference_batch_single_identical(
255213                continue 
256214
257215            value  =  inputs [name ]
258-             if  name  ==  "prompt" :
259-                 len_prompt  =  len (value )
260-                 batched_inputs [name ] =  [value [: len_prompt  //  i ] for  i  in  range (1 , batch_size  +  1 )]
261-                 batched_inputs [name ][- 1 ] =  100  *  "very long" 
262- 
263-             else :
264-                 batched_inputs [name ] =  batch_size  *  [value ]
216+             batched_inputs [name ] =  batch_size  *  [value ]
265217
266218        if  "generator"  in  inputs :
267219            batched_inputs ["generator" ] =  [self .get_generator (i ) for  i  in  range (batch_size )]
268220
269221        if  "batch_size"  in  inputs :
270222            batched_inputs ["batch_size" ] =  batch_size 
271223
272-         for  arg  in  additional_params_copy_to_batched_inputs :
273-             batched_inputs [arg ] =  inputs [arg ]
274- 
275224        output  =  pipe (** inputs , output = "images" )
276225        output_batch  =  pipe (** batched_inputs , output = "images" )
277226
278227        assert  output_batch .shape [0 ] ==  batch_size 
279228
280229        max_diff  =  np .abs (to_np (output_batch [0 ]) -  to_np (output [0 ])).max ()
281-         assert  max_diff  <  expected_max_diff 
230+         assert  max_diff  <  expected_max_diff ,  "Batch inference results different from single inference results" 
282231
283232    @unittest .skipIf (torch_device  not  in   ["cuda" , "xpu" ], reason = "float16 requires CUDA or XPU" ) 
284233    @require_accelerator  
285234    def  test_float16_inference (self , expected_max_diff = 5e-2 ):
286-         pipe  =  self .get_pipeline (torch_dtype = torch .float32 )
287- 
288-         pipe .to (torch_device )
235+         pipe  =  self .get_pipeline ()
236+         pipe .to (torch_device , torch .float32 )
289237        pipe .set_progress_bar_config (disable = None )
290238
291-         pipe_fp16  =  self .get_pipeline (torch_dtype = torch . float16 )
239+         pipe_fp16  =  self .get_pipeline ()
292240        pipe_fp16 .to (torch_device , torch .float16 )
293241        pipe_fp16 .set_progress_bar_config (disable = None )
294242
@@ -309,7 +257,7 @@ def test_float16_inference(self, expected_max_diff=5e-2):
309257            output_fp16  =  output_fp16 .cpu ()
310258
311259        max_diff  =  numpy_cosine_similarity_distance (output .flatten (), output_fp16 .flatten ())
312-         assert  max_diff  <  expected_max_diff 
260+         assert  max_diff  <  expected_max_diff ,  "FP16 inference is different from FP32 inference" 
313261
314262    @require_accelerator  
315263    def  test_to_device (self ):
@@ -320,19 +268,32 @@ def test_to_device(self):
320268        model_devices  =  [
321269            component .device .type  for  component  in  pipe .components .values () if  hasattr (component , "device" )
322270        ]
323-         self .assertTrue (all (device  ==  "cpu"  for  device  in  model_devices ))
324- 
325-         output_cpu  =  pipe (** self .get_dummy_inputs ("cpu" ), output = "images" )
326-         self .assertTrue (np .isnan (output_cpu ).sum () ==  0 )
271+         assert  all (device  ==  "cpu"  for  device  in  model_devices ), "All pipeline components are not on CPU" 
327272
328273        pipe .to (torch_device )
329274        model_devices  =  [
330275            component .device .type  for  component  in  pipe .components .values () if  hasattr (component , "device" )
331276        ]
332-         self .assertTrue (all (device  ==  torch_device  for  device  in  model_devices ))
277+         assert  all (
278+             device  ==  torch_device  for  device  in  model_devices 
279+         ), "All pipeline components are not on accelerator device" 
280+ 
281+     def  test_inference_is_not_nan_cpu (self ):
282+         pipe  =  self .get_pipeline ()
283+         pipe .set_progress_bar_config (disable = None )
284+         pipe .to ("cpu" )
285+ 
286+         output  =  pipe (** self .get_dummy_inputs ("cpu" ), output = "np" )
287+         assert  np .isnan (to_np (output )).sum () ==  0 , "CPU Inference returns NaN" 
288+ 
289+     @require_accelerator  
290+     def  test_inferece_is_not_nan (self ):
291+         pipe  =  self .get_pipeline ()
292+         pipe .set_progress_bar_config (disable = None )
293+         pipe .to (torch_device )
333294
334-         output_device  =  pipe (** self .get_dummy_inputs (torch_device ), output = "images " )
335-         self . assertTrue ( np .isnan (to_np (output_device )).sum () ==  0 ) 
295+         output  =  pipe (** self .get_dummy_inputs (torch_device ), output = "np " )
296+         assert   np .isnan (to_np (output )).sum () ==  0 ,  "Accelerator Inference returns NaN" 
336297
337298    def  test_num_images_per_prompt (self ):
338299        pipe  =  self .get_pipeline ()
0 commit comments