@@ -667,13 +667,14 @@ class OVPipelineForInpaintingTest(unittest.TestCase):
667667 if is_transformers_version (">=" , "4.40.0" ):
668668 SUPPORTED_ARCHITECTURES .append ("stable-diffusion-3" )
669669 SUPPORTED_ARCHITECTURES .append ("flux" )
670+ SUPPORTED_ARCHITECTURES .append ("flux-fill" )
670671
671672 AUTOMODEL_CLASS = AutoPipelineForInpainting
672673 OVMODEL_CLASS = OVPipelineForInpainting
673674
674675 TASK = "inpainting"
675676
676- def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" ):
677+ def generate_inputs (self , height = 128 , width = 128 , batch_size = 1 , channel = 3 , input_type = "pil" , model_arch = "" ):
677678 inputs = _generate_prompts (batch_size = batch_size )
678679
679680 inputs ["image" ] = _generate_images (
@@ -683,7 +684,8 @@ def generate_inputs(self, height=128, width=128, batch_size=1, channel=3, input_
683684 height = height , width = width , batch_size = batch_size , channel = 1 , input_type = input_type
684685 )
685686
686- inputs ["strength" ] = 0.75
687+ if model_arch != "flux-fill" :
688+ inputs ["strength" ] = 0.75
687689 inputs ["height" ] = height
688690 inputs ["width" ] = width
689691
@@ -699,7 +701,12 @@ def test_load_vanilla_model_which_is_not_supported(self):
699701 @parameterized .expand (SUPPORTED_ARCHITECTURES )
700702 @require_diffusers
701703 def test_ov_pipeline_class_dispatch (self , model_arch : str ):
702- auto_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
704+ if model_arch != "flux-fill" :
705+ auto_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
706+ else :
707+ from diffusers import FluxFillPipeline
708+
709+ auto_pipeline = FluxFillPipeline .from_pretrained (MODEL_NAMES [model_arch ])
703710 ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
704711
705712 self .assertEqual (ov_pipeline .auto_model_class , auto_pipeline .__class__ )
@@ -713,7 +720,9 @@ def test_num_images_per_prompt(self, model_arch: str):
713720 for height in [64 , 128 ]:
714721 for width in [64 , 128 ]:
715722 for num_images_per_prompt in [1 , 3 ]:
716- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
723+ inputs = self .generate_inputs (
724+ height = height , width = width , batch_size = batch_size , model_arch = model_arch
725+ )
717726 outputs = pipeline (** inputs , num_images_per_prompt = num_images_per_prompt ).images
718727 self .assertEqual (outputs .shape , (batch_size * num_images_per_prompt , height , width , 3 ))
719728
@@ -752,7 +761,9 @@ def test_shape(self, model_arch: str):
752761 height , width , batch_size = 128 , 64 , 1
753762
754763 for input_type in ["pil" , "np" , "pt" ]:
755- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , input_type = input_type )
764+ inputs = self .generate_inputs (
765+ height = height , width = width , batch_size = batch_size , input_type = input_type , model_arch = model_arch
766+ )
756767
757768 for output_type in ["pil" , "np" , "pt" , "latent" ]:
758769 inputs ["output_type" ] = output_type
@@ -764,7 +775,7 @@ def test_shape(self, model_arch: str):
764775 elif output_type == "pt" :
765776 self .assertEqual (outputs .shape , (batch_size , 3 , height , width ))
766777 else :
767- if model_arch != "flux" :
778+ if not model_arch . startswith ( "flux" ) :
768779 out_channels = (
769780 pipeline .unet .config .out_channels
770781 if pipeline .unet is not None
@@ -782,17 +793,26 @@ def test_shape(self, model_arch: str):
782793 else :
783794 packed_height = height // pipeline .vae_scale_factor // 2
784795 packed_width = width // pipeline .vae_scale_factor // 2
785- channels = pipeline .transformer .config .in_channels
796+ channels = (
797+ pipeline .transformer .config .in_channels
798+ if model_arch != "flux-fill"
799+ else pipeline .transformer .out_channels
800+ )
786801 self .assertEqual (outputs .shape , (batch_size , packed_height * packed_width , channels ))
787802
788803 @parameterized .expand (SUPPORTED_ARCHITECTURES )
789804 @require_diffusers
790805 def test_compare_to_diffusers_pipeline (self , model_arch : str ):
791806 ov_pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
792- diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
807+ if model_arch != "flux-fill" :
808+ diffusers_pipeline = self .AUTOMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
809+ else :
810+ from diffusers import FluxFillPipeline
811+
812+ diffusers_pipeline = FluxFillPipeline .from_pretrained (MODEL_NAMES [model_arch ])
793813
794814 height , width , batch_size = 64 , 64 , 1
795- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
815+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_arch = model_arch )
796816
797817 for output_type in ["latent" , "np" , "pt" ]:
798818 inputs ["output_type" ] = output_type
@@ -804,7 +824,7 @@ def test_compare_to_diffusers_pipeline(self, model_arch: str):
804824
805825 # test generation when input resolution nondevisible on 64
806826 height , width , batch_size = 96 , 96 , 1
807- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
827+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_arch = model_arch )
808828
809829 for output_type in ["latent" , "np" , "pt" ]:
810830 inputs ["output_type" ] = output_type
@@ -820,7 +840,7 @@ def test_image_reproducibility(self, model_arch: str):
820840 pipeline = self .OVMODEL_CLASS .from_pretrained (MODEL_NAMES [model_arch ])
821841
822842 height , width , batch_size = 64 , 64 , 1
823- inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size )
843+ inputs = self .generate_inputs (height = height , width = width , batch_size = batch_size , model_arch = model_arch )
824844
825845 for generator_framework in ["np" , "pt" ]:
826846 ov_outputs_1 = pipeline (** inputs , generator = get_generator (generator_framework , SEED ))
0 commit comments