@@ -527,7 +527,9 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
 
         The following scenarios are tested:
           - Single IP-Adapter with scale=0 should produce same output as no IP-Adapter.
+          - Multi IP-Adapter with scale=0 should produce same output as no IP-Adapter.
           - Single IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
+          - Multi IP-Adapter with scale!=0 should produce different output compared to no IP-Adapter.
         """
         # Raising the tolerance for this test when it's run on a CPU because we
         # compare against static slices and that can be shaky (with a VVVV low probability).
@@ -545,6 +547,7 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
         else:
             output_without_adapter = expected_pipe_slice
 
+        # 1. Single IP-Adapter test cases
         adapter_state_dict = create_flux_ip_adapter_state_dict(pipe.transformer)
         pipe.transformer._load_ip_adapter_weights(adapter_state_dict)
 
@@ -578,6 +581,44 @@ def test_ip_adapter(self, expected_max_diff: float = 1e-4, expected_pipe_slice=N
             max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
         )
 
+        # 2. Multi IP-Adapter test cases
+        adapter_state_dict_1 = create_flux_ip_adapter_state_dict(pipe.transformer)
+        adapter_state_dict_2 = create_flux_ip_adapter_state_dict(pipe.transformer)
+        pipe.transformer._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
+
+        # forward pass with multi ip adapter, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+        inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+        pipe.set_ip_adapter_scale([0.0, 0.0])
+        output_without_multi_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_without_multi_adapter_scale = output_without_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        # forward pass with multi ip adapter, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+        inputs["negative_ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(image_embed_dim)] * 2
+        pipe.set_ip_adapter_scale([42.0, 42.0])
+        output_with_multi_adapter_scale = pipe(**inputs)[0]
+        if expected_pipe_slice is not None:
+            output_with_multi_adapter_scale = output_with_multi_adapter_scale[0, -3:, -3:, -1].flatten()
+
+        max_diff_without_multi_adapter_scale = np.abs(
+            output_without_multi_adapter_scale - output_without_adapter
+        ).max()
+        max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
+        self.assertLess(
+            max_diff_without_multi_adapter_scale,
+            expected_max_diff,
+            "Output without multi-ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_multi_adapter_scale,
+            1e-2,
+            "Output with multi-ip-adapter scale must be different from normal inference",
+        )
+
 
 class PipelineLatentTesterMixin:
     """
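For reference, the assertion pattern shared by the single- and multi-adapter branches above can be distilled into a small standalone helper. This is a minimal sketch only: the helper name `check_ip_adapter_effect` and the random stand-in arrays are illustrative and not part of the test suite; only the two thresholds (`expected_max_diff` and the `1e-2` floor) come from the diff.

```python
import numpy as np


def check_ip_adapter_effect(baseline, scale_zero_output, scaled_output, expected_max_diff=1e-4):
    """Mirror of the test's checks: scale=0 must reproduce the no-adapter baseline,
    while a non-zero scale must move the output by more than the 1e-2 floor."""
    diff_zero = np.abs(scale_zero_output - baseline).max()
    diff_scaled = np.abs(scaled_output - baseline).max()
    assert diff_zero < expected_max_diff, "scale=0 output drifted from the no-adapter baseline"
    assert diff_scaled > 1e-2, "non-zero scale had no visible effect on the output"


# Illustrative usage with random stand-ins for the flattened 3x3 output slices.
rng = np.random.default_rng(0)
baseline = rng.random(9, dtype=np.float32)
check_ip_adapter_effect(baseline, baseline.copy(), baseline + 0.05)
```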