@@ -19,7 +19,6 @@
 import torch
 
 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer
 
 from ...configuration_utils import FrozenDict
@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor
 
+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,
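
Note: registering `safety_checker` and `feature_extractor` in `_optional_components` is what lets `DiffusionPipeline.from_pretrained` accept `None` for them instead of raising. A minimal sketch of the loading call this enables; the repo ID and `revision` below are illustrative, not taken from this diff:

```python
# Sketch: loading the ONNX inpainting pipeline without the optional modules.
from diffusers import OnnxStableDiffusionInpaintPipeline

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "runwayml/stable-diffusion-inpainting",  # illustrative checkpoint
    revision="onnx",                         # illustrative ONNX export branch
    safety_checker=None,       # permitted: listed in _optional_components
    feature_extractor=None,    # permitted: listed in _optional_components
)
```
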
@@ -149,27 +150,6 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )
 
-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
@@ -180,7 +160,6 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
 
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt
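
Note: the removed line could never work in this pipeline: `register_modules` above registers `vae_encoder` and `vae_decoder`, never a `vae` attribute, so `self.vae.config.block_out_channels` had nothing to resolve against. For the standard Stable Diffusion v1 VAE the factor it tried to derive is constant anyway, which is what lets later hunks hardcode it; a quick check under that assumption:

```python
# Assumption: standard SD v1 VAE config with 4 down blocks,
# as in the CompVis/runwayml checkpoints.
block_out_channels = (128, 256, 512, 512)

# The expression the removed line used, evaluated on that config:
vae_scale_factor = 2 ** (len(block_out_channels) - 1)
assert vae_scale_factor == 8  # hence the literal `// 8` later in this diff
```
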
@@ -267,8 +246,8 @@ def __call__(
         prompt: Union[str, List[str]],
         image: PIL.Image.Image,
         mask_image: PIL.Image.Image,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,
@@ -296,9 +275,9 @@ def __call__(
296
275
repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
297
276
to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
298
277
instead of 3, so the expected shape would be `(B, H, W, 1)`.
299
- height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor ):
278
+ height (`int`, *optional*, defaults to 512 ):
300
279
The height in pixels of the generated image.
301
- width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor ):
280
+ width (`int`, *optional*, defaults to 512 ):
302
281
The width in pixels of the generated image.
303
282
num_inference_steps (`int`, *optional*, defaults to 50):
304
283
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
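
Note: hardcoding 512 matches the old dynamic default for every SD v1-family checkpoint, presumably because an `OnnxRuntimeModel` UNet does not carry the diffusers `sample_size` config the old expression read. Arithmetic check under that assumption:

```python
# Old default: unet.config.sample_size * vae_scale_factor.
# For SD v1 checkpoints: sample_size == 64 and vae_scale_factor == 8.
sample_size, vae_scale_factor = 64, 8
assert sample_size * vae_scale_factor == 512  # identical to the new literal default
```
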
@@ -343,9 +322,6 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
 
         if isinstance(prompt, str):
             batch_size = 1
@@ -381,12 +357,7 @@ def __call__(
381
357
)
382
358
383
359
num_channels_latents = NUM_LATENT_CHANNELS
384
- latents_shape = (
385
- batch_size * num_images_per_prompt ,
386
- num_channels_latents ,
387
- height // self .vae_scale_factor ,
388
- width // self .vae_scale_factor ,
389
- )
360
+ latents_shape = (batch_size * num_images_per_prompt , num_channels_latents , height // 8 , width // 8 )
390
361
latents_dtype = text_embeddings .dtype
391
362
if latents is None :
392
363
latents = generator .randn (* latents_shape ).astype (latents_dtype )
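
Note: a worked example of the inlined shape for the default 512x512 call. It assumes `NUM_LATENT_CHANNELS` is 4 (the Stable Diffusion latent channel count) and that `generator` is a `numpy.random.RandomState`, as the ONNX pipelines use; the seed is illustrative.

```python
import numpy as np

# The new latents_shape, evaluated for the 512x512 default with one image.
batch_size, num_images_per_prompt = 1, 1
num_channels_latents = 4  # assumed value of NUM_LATENT_CHANNELS
height = width = 512
latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
assert latents_shape == (1, 4, 64, 64)

# Sampling mirrors the last line of the hunk above.
generator = np.random.RandomState(0)  # illustrative seed
latents = generator.randn(*latents_shape).astype(np.float32)
assert latents.shape == (1, 4, 64, 64)
```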