2424    AutoPipelineForInpainting ,
2525    AutoPipelineForText2Image ,
2626    DiffusionPipeline ,
27-     FluxKontextPipeline ,
2827)
2928from  diffusers .pipelines .stable_diffusion  import  StableDiffusionSafetyChecker 
3029from  diffusers .utils  import  load_image 
@@ -485,7 +484,8 @@ class OVPipelineForImage2ImageTest(unittest.TestCase):
485484    if  is_transformers_version (">=" , "4.40.0" ):
486485        SUPPORTED_ARCHITECTURES .append ("stable-diffusion-3" )
487486        SUPPORTED_ARCHITECTURES .append ("flux" )
488-         SUPPORTED_ARCHITECTURES .append ("flux-kontext" )
487+         if  is_diffusers_version (">=" , "0.35.0" ):
488+             SUPPORTED_ARCHITECTURES .append ("flux-kontext" )
489489
490490    AUTOMODEL_CLASS  =  AutoPipelineForImage2Image 
491491    OVMODEL_CLASS  =  OVPipelineForImage2Image 
@@ -502,8 +502,9 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
502502        if  model_type  in  ["flux" , "stable-diffusion-3" , "flux-kontext" ]:
503503            inputs ["height" ] =  height 
504504            inputs ["width" ] =  width 
505- 
506-         inputs ["strength" ] =  0.75 
505+             
506+         if  model_type  !=  "flux-kontext" :
507+             inputs ["strength" ] =  0.75 
507508
508509        return  inputs 
509510
@@ -535,7 +536,15 @@ def test_num_images_per_prompt(self, model_arch: str):
535536                            height = height , width = width , batch_size = batch_size , model_type = model_arch 
536537                        )
537538                        outputs  =  pipeline (** inputs , num_images_per_prompt = num_images_per_prompt ).images 
538-                         self .assertEqual (outputs .shape , (batch_size  *  num_images_per_prompt , height , width , 3 ))
539+                         if  model_arch  !=  "flux-kontext" :
540+                             self .assertEqual (outputs .shape , (batch_size  *  num_images_per_prompt , height , width , 3 ))
541+                         else :
542+                             if  (height  ==  width ):
543+                                 self .assertEqual (outputs .shape , (batch_size  *  num_images_per_prompt , 1024 , 1024 , 3 ))
544+                             elif  (height  >  width ):
545+                                 self .assertEqual (outputs .shape , (batch_size  *  num_images_per_prompt , 1448 , 724 , 3 ))
546+                             else :
547+                                 self .assertEqual (outputs .shape , (batch_size  *  num_images_per_prompt , 724 , 1448 , 3 )) 
539548
540549    @parameterized .expand (["stable-diffusion" , "stable-diffusion-xl" , "latent-consistency" ]) 
541550    @require_diffusers  
@@ -568,8 +577,10 @@ def __call__(self, *args, **kwargs) -> None:
568577    @require_diffusers  
569578    def  test_shape (self , model_arch : str ):
570579        pipeline  =  self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
571- 
572-         height , width , batch_size  =  128 , 64 , 1 
580+         if  model_arch  !=  "flux-kontext" :
581+             height , width , batch_size  =  128 , 64 , 1 
582+         else :
583+             height , width , batch_size  =  1448 , 724 , 1 
573584
574585        for  input_type  in  ["pil" , "np" , "pt" ]:
575586            inputs  =  self .generate_inputs (
@@ -586,7 +597,7 @@ def test_shape(self, model_arch: str):
586597                elif  output_type  ==  "pt" :
587598                    self .assertEqual (outputs .shape , (batch_size , 3 , height , width ))
588599                else :
589-                     if  model_arch   !=   "flux"   and   model_arch   !=   "flux-kontext"  :
600+                     if  not   model_arch . startswith ( "flux"  ) :
590601                        out_channels  =  (
591602                            pipeline .unet .config .out_channels 
592603                            if  pipeline .unet  is  not None 
@@ -611,9 +622,9 @@ def test_shape(self, model_arch: str):
611622    @require_diffusers  
612623    def  test_compare_to_diffusers_pipeline (self , model_arch : str ):
613624        height , width , batch_size  =  128 , 128 , 1 
614-         inputs  =  self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )
615- 
616-         auto_cls   =   self . AUTOMODEL_CLASS   if   "flux-kontext"   not   in   model_arch   else   FluxKontextPipeline 
625+         inputs  =  self .generate_inputs (height = height , width = width , batch_size = batch_size , model_type = model_arch )         
626+          auto_cls   =   self . AUTOMODEL_CLASS 
627+         
617628        diffusers_pipeline  =  auto_cls .from_pretrained (MODEL_NAMES [model_arch ])
618629        ov_pipeline  =  self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
619630
0 commit comments