@@ -838,6 +838,133 @@ def apply_overlay(
838838        return  image 
839839
840840
841+ class  InpaintProcessor (ConfigMixin ):
842+     """ 
843+     Image processor for inpainting image and mask. 
844+     """ 
845+     config_name  =  CONFIG_NAME 
846+ 
847+     @register_to_config  
848+     def  __init__ (
849+         self , 
850+         do_resize : bool  =  True ,
851+         vae_scale_factor : int  =  8 ,
852+         vae_latent_channels : int  =  4 ,
853+         resample : str  =  "lanczos" ,
854+         reducing_gap : int  =  None ,
855+         do_normalize : bool  =  True ,
856+         do_binarize : bool  =  False ,
857+         do_convert_grayscale : bool  =  False ,
858+         mask_do_normalize : bool  =  False , 
859+         mask_do_binarize : bool  =  True , 
860+         mask_do_convert_grayscale : bool  =  True ,
861+         ):
862+ 
863+         super ().__init__ ()
864+ 
865+         self ._image_processor  =  VaeImageProcessor (
866+             do_resize = do_resize ,
867+             vae_scale_factor = vae_scale_factor , 
868+             vae_latent_channels = vae_latent_channels ,
869+             resample = resample ,
870+             reducing_gap = reducing_gap ,
871+             do_normalize = do_normalize ,
872+             do_binarize = do_binarize ,
873+             do_convert_grayscale = do_convert_grayscale ,
874+             )
875+         self ._mask_processor  =  VaeImageProcessor (
876+             do_resize = do_resize ,
877+             vae_scale_factor = vae_scale_factor , 
878+             vae_latent_channels = vae_latent_channels ,
879+             resample = resample ,
880+             reducing_gap = reducing_gap ,
881+             do_normalize = mask_do_normalize , 
882+             do_binarize = mask_do_binarize , 
883+             do_convert_grayscale = mask_do_convert_grayscale , 
884+             )
885+ 
886+     
887+     def  preprocess (
888+         self ,
889+         image : PIL .Image .Image ,
890+         mask : PIL .Image .Image ,
891+         height :int ,
892+         width :int ,
893+         padding_mask_crop :Optional [int ] =  None ,
894+     ) ->  Tuple [torch .Tensor , torch .Tensor ]:
895+         """ 
896+         Preprocess the image and mask. 
897+         """ 
898+         
899+         if  padding_mask_crop  is  not None :
900+             crops_coords  =  self ._image_processor .get_crop_region (
901+                 mask , width , height , pad = padding_mask_crop 
902+             )
903+             resize_mode  =  "fill" 
904+         else :
905+             crops_coords  =  None 
906+             resize_mode  =  "default" 
907+         
908+         processed_image  =  self ._image_processor .preprocess (
909+             image ,
910+             height = height ,
911+             width = width ,
912+             crops_coords = crops_coords ,
913+             resize_mode = resize_mode ,
914+         )
915+ 
916+         processed_mask  =  self ._mask_processor .preprocess (
917+             mask ,
918+             height = height ,
919+             width = width ,
920+             resize_mode = resize_mode ,
921+             crops_coords = crops_coords ,
922+         )
923+ 
924+         
925+         if  crops_coords  is  not None :
926+             postprocessing_kwargs  =  {
927+                 "crops_coords" : crops_coords ,
928+                 "original_image" : image ,
929+                 "original_mask" : mask ,
930+             }
931+         else :
932+             postprocessing_kwargs  =  {
933+                 "crops_coords" : None ,
934+                 "original_image" : None ,
935+                 "original_mask" : None ,
936+             }
937+ 
938+         return  processed_image , processed_mask , postprocessing_kwargs 
939+ 
940+     
941+     def  postprocess (
942+         self ,
943+         image : torch .Tensor ,
944+         output_type : str  =  "pil" ,
945+         original_image : Optional [PIL .Image .Image ] =  None ,
946+         original_mask : Optional [PIL .Image .Image ] =  None ,
947+         crops_coords : Optional [Tuple [int , int , int , int ]] =  None ,
948+     ) ->  Tuple [PIL .Image .Image , PIL .Image .Image ]:
949+         """ 
950+         Postprocess the image, optionally apply mask overlay 
951+         """ 
952+         image  =  self ._image_processor .postprocess (
953+             image ,
954+             output_type = output_type ,
955+         )
956+         # optionally apply the mask overlay 
957+         if  crops_coords  is  not None  and  (original_image  is  None  or  original_mask  is  None ):
958+             raise  ValueError ("original_image and original_mask must be provided if crops_coords is provided" )
959+ 
960+         elif  crops_coords  is  not None :
961+             image  =  [self ._image_processor .apply_overlay (
962+                 original_mask , original_image , i , crops_coords 
963+             ) for  i  in  image ]
964+         
965+         return  image 
966+ 
967+ 
841968class  VaeImageProcessorLDM3D (VaeImageProcessor ):
842969    """ 
843970    Image processor for VAE LDM3D. 
0 commit comments