55import torch
66
77import PIL
8- from tqdm .auto import tqdm
98from transformers import CLIPFeatureExtractor , CLIPTokenizer
109
1110from ...configuration_utils import FrozenDict
1615from . import StableDiffusionPipelineOutput
1716
1817
19- logger = logging .get_logger (__name__ )
18+ logger = logging .get_logger (__name__ ) # pylint: disable=invalid-name
2019
2120
22- def preprocess_image (image ):
23- w , h = image .size
24- w , h = map (lambda x : x - x % 32 , (w , h )) # resize to integer multiple of 32
25- image = image .resize ((w , h ), resample = PIL .Image .LANCZOS )
26- image = np .array (image ).astype (np .float32 ) / 255.0
21+ NUM_UNET_INPUT_CHANNELS = 9
22+ NUM_LATENT_CHANNELS = 4
23+
24+
25+ def prepare_mask_and_masked_image (image , mask , latents_shape ):
26+ image = np .array (image .convert ("RGB" ))
2727 image = image [None ].transpose (0 , 3 , 1 , 2 )
28- return 2.0 * image - 1.0
28+ image = image . astype ( np . float32 ) / 127.5 - 1.0
2929
30+ image_mask = np .array (mask .convert ("L" ))
31+ masked_image = image * (image_mask < 127.5 )
3032
31- def preprocess_mask (mask ):
32- mask = mask .convert ("L" )
33- w , h = mask .size
34- w , h = map (lambda x : x - x % 32 , (w , h )) # resize to integer multiple of 32
35- mask = mask .resize ((w // 8 , h // 8 ), resample = PIL .Image .NEAREST )
36- mask = np .array (mask ).astype (np .float32 ) / 255.0
37- mask = np .tile (mask , (4 , 1 , 1 ))
38- mask = mask [None ].transpose (0 , 1 , 2 , 3 ) # what does this step do?
39- mask = 1 - mask # repaint white, keep black
40- return mask
33+ mask = mask .resize ((latents_shape [1 ], latents_shape [0 ]), PIL .Image .NEAREST )
34+ mask = np .array (mask .convert ("L" ))
35+ mask = mask .astype (np .float32 ) / 255.0
36+ mask = mask [None , None ]
37+ mask [mask < 0.5 ] = 0
38+ mask [mask >= 0.5 ] = 1
39+
40+ return mask , masked_image
4141
4242
4343class OnnxStableDiffusionInpaintPipeline (DiffusionPipeline ):
@@ -129,14 +129,16 @@ def __init__(
129129 def __call__ (
130130 self ,
131131 prompt : Union [str , List [str ]],
132- init_image : Union [np .ndarray , PIL .Image .Image ],
133- mask_image : Union [np .ndarray , PIL .Image .Image ],
134- strength : float = 0.8 ,
135- num_inference_steps : Optional [int ] = 50 ,
136- guidance_scale : Optional [float ] = 7.5 ,
132+ image : PIL .Image .Image ,
133+ mask_image : PIL .Image .Image ,
134+ height : int = 512 ,
135+ width : int = 512 ,
136+ num_inference_steps : int = 50 ,
137+ guidance_scale : float = 7.5 ,
137138 negative_prompt : Optional [Union [str , List [str ]]] = None ,
138139 num_images_per_prompt : Optional [int ] = 1 ,
139- eta : Optional [float ] = 0.0 ,
140+ eta : float = 0.0 ,
141+ latents : Optional [np .ndarray ] = None ,
140142 output_type : Optional [str ] = "pil" ,
141143 return_dict : bool = True ,
142144 callback : Optional [Callable [[int , int , np .ndarray ], None ]] = None ,
@@ -149,22 +151,21 @@ def __call__(
149151 Args:
150152 prompt (`str` or `List[str]`):
151153 The prompt or prompts to guide the image generation.
152- init_image (`np.ndarray` or `PIL.Image.Image`):
153- `Image`, or tensor representing an image batch, that will be used as the starting point for the
154- process. This is the image whose masked region will be inpainted.
155- mask_image (`np.ndarray` or `PIL.Image.Image`):
156- `Image`, or tensor representing an image batch, to mask `init_image`. White pixels in the mask will be
157- replaced by noise and therefore repainted, while black pixels will be preserved. If `mask_image` is a
158- PIL image, it will be converted to a single channel (luminance) before use. If it's a tensor, it should
159- contain one color channel (L) instead of 3, so the expected shape would be `(B, H, W, 1)`.
160- strength (`float`, *optional*, defaults to 0.8):
161- Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1. When `strength`
162- is 1, the denoising process will be run on the masked area for the full number of iterations specified
163- in `num_inference_steps`. `init_image` will be used as a reference for the masked area, adding more
164- noise to that region the larger the `strength`. If `strength` is 0, no inpainting will occur.
154+ image (`PIL.Image.Image`):
155+ `Image`, or tensor representing an image batch which will be inpainted, *i.e.* parts of the image will
156+ be masked out with `mask_image` and repainted according to `prompt`.
157+ mask_image (`PIL.Image.Image`):
158+ `Image`, or tensor representing an image batch, to mask `image`. White pixels in the mask will be
159+ repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
160+ to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
161+ instead of 3, so the expected shape would be `(B, H, W, 1)`.
162+ height (`int`, *optional*, defaults to 512):
163+ The height in pixels of the generated image.
164+ width (`int`, *optional*, defaults to 512):
165+ The width in pixels of the generated image.
165166 num_inference_steps (`int`, *optional*, defaults to 50):
166- The reference number of denoising steps. More denoising steps usually lead to a higher quality image at
167- the expense of slower inference. This parameter will be modulated by `strength`, as explained above .
167+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
168+ expense of slower inference.
168169 guidance_scale (`float`, *optional*, defaults to 7.5):
169170 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
170171 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -179,6 +180,10 @@ def __call__(
179180 eta (`float`, *optional*, defaults to 0.0):
180181 Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
181182 [`schedulers.DDIMScheduler`], will be ignored for others.
183+ latents (`np.ndarray`, *optional*):
184+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
185+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
186+ tensor will ge generated by sampling using the supplied random `generator`.
182187 output_type (`str`, *optional*, defaults to `"pil"`):
183188 The output format of the generate image. Choose between
184189 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -206,8 +211,8 @@ def __call__(
206211 else :
207212 raise ValueError (f"`prompt` has to be of type `str` or `list` but is { type (prompt )} " )
208213
209- if strength < 0 or strength > 1 :
210- raise ValueError (f"The value of strength should in [0.0, 1.0] but is { strength } " )
214+ if height % 8 != 0 or width % 8 != 0 :
215+ raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } . " )
211216
212217 if (callback_steps is None ) or (
213218 callback_steps is not None and (not isinstance (callback_steps , int ) or callback_steps <= 0 )
@@ -285,41 +290,46 @@ def __call__(
285290 # to avoid doing two forward passes
286291 text_embeddings = np .concatenate ([uncond_embeddings , text_embeddings ])
287292
288- # preprocess image
289- if not isinstance (init_image , torch .FloatTensor ):
290- init_image = preprocess_image (init_image )
293+ num_channels_latents = NUM_LATENT_CHANNELS
294+ latents_shape = (batch_size * num_images_per_prompt , num_channels_latents , height // 8 , width // 8 )
295+ latents_dtype = text_embeddings .dtype
296+ if latents is None :
297+ latents = np .random .randn (* latents_shape ).astype (latents_dtype )
298+ else :
299+ if latents .shape != latents_shape :
300+ raise ValueError (f"Unexpected latents shape, got { latents .shape } , expected { latents_shape } " )
291301
292- # encode the init image into latents and scale the latents
293- init_latents = self .vae_encoder (sample = init_image )[0 ]
294- init_latents = 0.18215 * init_latents
302+ # prepare mask and masked_image
303+ mask , masked_image = prepare_mask_and_masked_image (image , mask_image , latents_shape [- 2 :])
304+ mask = mask .astype (latents .dtype )
305+ masked_image = masked_image .astype (latents .dtype )
295306
296- # Expand init_latents for batch_size and num_images_per_prompt
297- init_latents = np .concatenate ([init_latents ] * batch_size * num_images_per_prompt , axis = 0 )
298- init_latents_orig = init_latents
307+ masked_image_latents = self .vae_encoder (sample = masked_image )[0 ]
308+ masked_image_latents = 0.18215 * masked_image_latents
299309
300- # preprocess mask
301- if not isinstance ( mask_image , np . ndarray ):
302- mask_image = preprocess_mask ( mask_image )
303- mask = np . concatenate ([ mask_image ] * batch_size * num_images_per_prompt )
310+ mask = np . concatenate ([ mask ] * 2 ) if do_classifier_free_guidance else mask
311+ masked_image_latents = (
312+ np . concatenate ([ masked_image_latents ] * 2 ) if do_classifier_free_guidance else masked_image_latents
313+ )
304314
305- # check sizes
306- if not mask .shape == init_latents .shape :
307- raise ValueError ("The mask and init_image should be the same size!" )
315+ num_channels_mask = mask .shape [1 ]
316+ num_channels_masked_image = masked_image_latents .shape [1 ]
308317
309- # get the original timestep using init_timestep
310- offset = self .scheduler .config .get ("steps_offset" , 0 )
311- init_timestep = int (num_inference_steps * strength ) + offset
312- init_timestep = min (init_timestep , num_inference_steps )
318+ unet_input_channels = NUM_UNET_INPUT_CHANNELS
319+ if num_channels_latents + num_channels_mask + num_channels_masked_image != unet_input_channels :
320+ raise ValueError (
321+ "Incorrect configuration settings! The config of `pipeline.unet` expects"
322+ f" { unet_input_channels } but received `num_channels_latents`: { num_channels_latents } +"
323+ f" `num_channels_mask`: { num_channels_mask } + `num_channels_masked_image`: { num_channels_masked_image } "
324+ f" = { num_channels_latents + num_channels_masked_image + num_channels_mask } . Please verify the config of"
325+ " `pipeline.unet` or your `mask_image` or `image` input."
326+ )
313327
314- timesteps = self . scheduler . timesteps . numpy ()[ - init_timestep ]
315- timesteps = np . array ([ timesteps ] * batch_size * num_images_per_prompt )
328+ # set timesteps
329+ self . scheduler . set_timesteps ( num_inference_steps )
316330
317- # add noise to latents using the timesteps
318- noise = np .random .randn (* init_latents .shape ).astype (np .float32 )
319- init_latents = self .scheduler .add_noise (
320- torch .from_numpy (init_latents ), torch .from_numpy (noise ), torch .from_numpy (timesteps )
321- )
322- init_latents = init_latents .numpy ()
331+ # scale the initial noise by the standard deviation required by the scheduler
332+ latents = latents * self .scheduler .init_noise_sigma
323333
324334 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
325335 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -330,15 +340,13 @@ def __call__(
330340 if accepts_eta :
331341 extra_step_kwargs ["eta" ] = eta
332342
333- latents = init_latents
334-
335- t_start = max (num_inference_steps - init_timestep + offset , 0 )
336- timesteps = self .scheduler .timesteps [t_start :].numpy ()
337-
338- for i , t in tqdm (enumerate (timesteps )):
343+ for i , t in enumerate (self .progress_bar (self .scheduler .timesteps )):
339344 # expand the latents if we are doing classifier free guidance
340345 latent_model_input = np .concatenate ([latents ] * 2 ) if do_classifier_free_guidance else latents
341- latent_model_input = self .scheduler .scale_model_input (latent_model_input , t )
346+ # concat latents, mask, masked_image_latnets in the channel dimension
347+ latent_model_input = np .concatenate ([latent_model_input , mask , masked_image_latents ], axis = 1 )
348+ latent_model_input = self .scheduler .scale_model_input (torch .from_numpy (latent_model_input ), t )
349+ latent_model_input = latent_model_input .numpy ()
342350
343351 # predict the noise residual
344352 noise_pred = self .unet (
@@ -353,12 +361,6 @@ def __call__(
353361 # compute the previous noisy sample x_t -> x_t-1
354362 latents = self .scheduler .step (noise_pred , t , latents , ** extra_step_kwargs ).prev_sample
355363 latents = latents .numpy ()
356- # masking
357- init_latents_proper = self .scheduler .add_noise (
358- torch .from_numpy (init_latents_orig ), torch .from_numpy (noise ), torch .tensor ([t ])
359- )
360-
361- latents = (init_latents_proper * mask ) + (latents * (1 - mask ))
362364
363365 # call the callback, if provided
364366 if callback is not None and i % callback_steps == 0 :
0 commit comments