Commit 0a9f7f9

up

committed
1 parent 84dbf17 commit 0a9f7f9

File tree

7 files changed: +1074 −378 lines changed


src/diffusers/image_processor.py

Lines changed: 127 additions & 0 deletions
@@ -838,6 +838,133 @@ def apply_overlay(
         return image
 
 
+class InpaintProcessor(ConfigMixin):
+    """
+    Image processor for inpainting image and mask.
+    """
+
+    config_name = CONFIG_NAME
+
+    @register_to_config
+    def __init__(
+        self,
+        do_resize: bool = True,
+        vae_scale_factor: int = 8,
+        vae_latent_channels: int = 4,
+        resample: str = "lanczos",
+        reducing_gap: int = None,
+        do_normalize: bool = True,
+        do_binarize: bool = False,
+        do_convert_grayscale: bool = False,
+        mask_do_normalize: bool = False,
+        mask_do_binarize: bool = True,
+        mask_do_convert_grayscale: bool = True,
+    ):
+        super().__init__()
+
+        self._image_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=do_normalize,
+            do_binarize=do_binarize,
+            do_convert_grayscale=do_convert_grayscale,
+        )
+        self._mask_processor = VaeImageProcessor(
+            do_resize=do_resize,
+            vae_scale_factor=vae_scale_factor,
+            vae_latent_channels=vae_latent_channels,
+            resample=resample,
+            reducing_gap=reducing_gap,
+            do_normalize=mask_do_normalize,
+            do_binarize=mask_do_binarize,
+            do_convert_grayscale=mask_do_convert_grayscale,
+        )
+
+    def preprocess(
+        self,
+        image: PIL.Image.Image,
+        mask: PIL.Image.Image,
+        height: int,
+        width: int,
+        padding_mask_crop: Optional[int] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Preprocess the image and mask.
+        """
+        if padding_mask_crop is not None:
+            crops_coords = self._image_processor.get_crop_region(
+                mask, width, height, pad=padding_mask_crop
+            )
+            resize_mode = "fill"
+        else:
+            crops_coords = None
+            resize_mode = "default"
+
+        processed_image = self._image_processor.preprocess(
+            image,
+            height=height,
+            width=width,
+            crops_coords=crops_coords,
+            resize_mode=resize_mode,
+        )
+
+        processed_mask = self._mask_processor.preprocess(
+            mask,
+            height=height,
+            width=width,
+            resize_mode=resize_mode,
+            crops_coords=crops_coords,
+        )
+
+        if crops_coords is not None:
+            postprocessing_kwargs = {
+                "crops_coords": crops_coords,
+                "original_image": image,
+                "original_mask": mask,
+            }
+        else:
+            postprocessing_kwargs = {
+                "crops_coords": None,
+                "original_image": None,
+                "original_mask": None,
+            }
+
+        return processed_image, processed_mask, postprocessing_kwargs
+
+    def postprocess(
+        self,
+        image: torch.Tensor,
+        output_type: str = "pil",
+        original_image: Optional[PIL.Image.Image] = None,
+        original_mask: Optional[PIL.Image.Image] = None,
+        crops_coords: Optional[Tuple[int, int, int, int]] = None,
+    ) -> Tuple[PIL.Image.Image, PIL.Image.Image]:
+        """
+        Postprocess the image, optionally apply mask overlay
+        """
+        image = self._image_processor.postprocess(
+            image,
+            output_type=output_type,
+        )
+        # optionally apply the mask overlay
+        if crops_coords is not None and (original_image is None or original_mask is None):
+            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")
+
+        elif crops_coords is not None:
+            image = [self._image_processor.apply_overlay(
+                original_mask, original_image, i, crops_coords
+            ) for i in image]
+
+        return image
+
+
 class VaeImageProcessorLDM3D(VaeImageProcessor):
     """
     Image processor for VAE LDM3D.
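
For context, the snippet below is a minimal usage sketch of the new InpaintProcessor, assuming it is exposed from diffusers.image_processor as added in this file. The image paths and the run_denoising_loop helper are placeholders invented for illustration; only the preprocess/postprocess calls follow the API introduced in this commit.

import PIL.Image
import torch
from diffusers.image_processor import InpaintProcessor


def run_denoising_loop(image_t, mask_t):
    # Placeholder for the actual diffusion loop + VAE decode (not part of this commit):
    # returns a dummy decoded batch in [-1, 1] with the same spatial size as the input.
    return torch.rand(1, 3, image_t.shape[-2], image_t.shape[-1]) * 2 - 1


processor = InpaintProcessor(vae_scale_factor=8)

init_image = PIL.Image.open("dog.png").convert("RGB")      # placeholder path
mask_image = PIL.Image.open("dog_mask.png").convert("L")   # placeholder path

# preprocess() returns normalized image/mask tensors plus the kwargs needed
# to undo the crop and re-apply the overlay after generation.
image_t, mask_t, post_kwargs = processor.preprocess(
    init_image,
    mask_image,
    height=512,
    width=512,
    padding_mask_crop=32,  # crop to the masked region with 32 px of padding
)

decoded = run_denoising_loop(image_t, mask_t)

# postprocess() converts to PIL and, because a crop was used, pastes the result
# back onto the original image via apply_overlay(); it returns a list of images.
result = processor.postprocess(decoded, output_type="pil", **post_kwargs)
result[0].save("inpainted.png")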
