
Commit 86aa747

Fix ONNX conversion and inference (#1416)

1 parent d52388f

4 files changed: +18 −94 lines

scripts/convert_stable_diffusion_checkpoint_to_onnx.py

Lines changed: 4 additions & 1 deletion
@@ -215,8 +215,10 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
         )
         del pipeline.safety_checker
         safety_checker = OnnxRuntimeModel.from_pretrained(output_path / "safety_checker")
+        feature_extractor = pipeline.feature_extractor
     else:
         safety_checker = None
+        feature_extractor = None

     onnx_pipeline = OnnxStableDiffusionPipeline(
         vae_encoder=OnnxRuntimeModel.from_pretrained(output_path / "vae_encoder"),

@@ -226,7 +228,8 @@ def convert_models(model_path: str, output_path: str, opset: int, fp16: bool = False):
         unet=OnnxRuntimeModel.from_pretrained(output_path / "unet"),
         scheduler=pipeline.scheduler,
         safety_checker=safety_checker,
-        feature_extractor=pipeline.feature_extractor,
+        feature_extractor=feature_extractor,
+        requires_safety_checker=safety_checker is not None,
     )

     onnx_pipeline.save_pretrained(output_path)
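
For context, a minimal sketch of what the fix enables on the consumer side; the output directory "./sd_onnx" and the `provider` choice are illustrative assumptions, not part of this commit. A checkpoint converted without a safety checker should now load cleanly, with both optional components set to None.

# Hypothetical check, assuming convert_models() above was pointed at
# output_path="./sd_onnx". With the _optional_components change below,
# a conversion done without a safety checker loads with safety_checker
# and feature_extractor both None instead of erroring out.
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./sd_onnx", provider="CPUExecutionProvider"
)
print(pipe.safety_checker is None, pipe.feature_extractor is None)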

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py

Lines changed: 5 additions & 35 deletions
@@ -18,7 +18,6 @@
 import numpy as np
 import torch

-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict

@@ -42,6 +41,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor

+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,

@@ -99,27 +100,6 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )

-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,

@@ -130,7 +110,6 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):

@@ -213,8 +192,8 @@ def _encode_prompt(self, prompt, num_images_per_prompt, do_classifier_free_guidance, negative_prompt):
     def __call__(
         self,
         prompt: Union[str, List[str]],
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: Optional[int] = 50,
         guidance_scale: Optional[float] = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,

@@ -228,10 +207,6 @@ def __call__(
         callback_steps: Optional[int] = 1,
         **kwargs,
     ):
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor
-
         if isinstance(prompt, str):
             batch_size = 1
         elif isinstance(prompt, list):

@@ -264,12 +239,7 @@ def __call__(

         # get the initial random noise unless the user supplied it
         latents_dtype = text_embeddings.dtype
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            4,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, 4, height // 8, width // 8)
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
         elif latents.shape != latents_shape:
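
A smoke-test sketch for the fixed text-to-image path (the checkpoint path and prompt are placeholders): `height` and `width` now default to 512 directly, and the latent grid is `height // 8` by `width // 8`, so the default resolution yields (1, 4, 64, 64) latents for a single image.

# Text-to-image sketch against a converted checkpoint at "./sd_onnx"
# (placeholder path). The ONNX pipeline takes a NumPy RandomState as its
# generator, matching the generator.randn(*latents_shape) call above.
import numpy as np
from diffusers import OnnxStableDiffusionPipeline

pipe = OnnxStableDiffusionPipeline.from_pretrained(
    "./sd_onnx", provider="CPUExecutionProvider"
)
generator = np.random.RandomState(0)
image = pipe(
    "an astronaut riding a horse on the moon",
    num_inference_steps=25,
    generator=generator,
).images[0]
image.save("astronaut.png")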

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py

Lines changed: 2 additions & 22 deletions
@@ -19,7 +19,6 @@
 import torch

 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict

@@ -78,6 +77,8 @@ class OnnxStableDiffusionImg2ImgPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor

+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,

@@ -135,27 +136,6 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )

-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,
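
The img2img pipeline gets the same `_optional_components` and deprecation-check cleanup; a hedged usage sketch follows (paths, prompt, and the `image` keyword are assumptions based on the surrounding signatures, not shown in this diff):

# Image-to-image sketch; "./sd_onnx" and "input.png" are placeholders.
from PIL import Image
from diffusers import OnnxStableDiffusionImg2ImgPipeline

pipe = OnnxStableDiffusionImg2ImgPipeline.from_pretrained(
    "./sd_onnx", provider="CPUExecutionProvider"
)
init_image = Image.open("input.png").convert("RGB").resize((512, 512))
image = pipe(
    "a watercolor painting of the same scene",
    image=init_image,
    strength=0.75,
).images[0]
image.save("watercolor.png")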

src/diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py

Lines changed: 7 additions & 36 deletions
@@ -19,7 +19,6 @@
 import torch

 import PIL
-from packaging import version
 from transformers import CLIPFeatureExtractor, CLIPTokenizer

 from ...configuration_utils import FrozenDict

@@ -91,6 +90,8 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
     safety_checker: OnnxRuntimeModel
     feature_extractor: CLIPFeatureExtractor

+    _optional_components = ["safety_checker", "feature_extractor"]
+
     def __init__(
         self,
         vae_encoder: OnnxRuntimeModel,

@@ -149,27 +150,6 @@ def __init__(
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
            )

-        is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
-            version.parse(unet.config._diffusers_version).base_version
-        ) < version.parse("0.9.0.dev0")
-        is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
-        if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
-            deprecation_message = (
-                "The configuration file of the unet has set the default `sample_size` to smaller than"
-                " 64 which seems highly unlikely .If you're checkpoint is a fine-tuned version of any of the"
-                " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
-                " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
-                " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
-                " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
-                " in the config might lead to incorrect results in future versions. If you have downloaded this"
-                " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
-                " the `unet/config.json` file"
-            )
-            deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
-            new_config = dict(unet.config)
-            new_config["sample_size"] = 64
-            unet._internal_dict = FrozenDict(new_config)
-
         self.register_modules(
             vae_encoder=vae_encoder,
             vae_decoder=vae_decoder,

@@ -180,7 +160,6 @@ def __init__(
             safety_checker=safety_checker,
             feature_extractor=feature_extractor,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.register_to_config(requires_safety_checker=requires_safety_checker)

     # Copied from diffusers.pipelines.stable_diffusion.pipeline_onnx_stable_diffusion.OnnxStableDiffusionPipeline._encode_prompt

@@ -267,8 +246,8 @@ def __call__(
         prompt: Union[str, List[str]],
         image: PIL.Image.Image,
         mask_image: PIL.Image.Image,
-        height: Optional[int] = None,
-        width: Optional[int] = None,
+        height: Optional[int] = 512,
+        width: Optional[int] = 512,
         num_inference_steps: int = 50,
         guidance_scale: float = 7.5,
         negative_prompt: Optional[Union[str, List[str]]] = None,

@@ -296,9 +275,9 @@ def __call__(
                 repainted, while black pixels will be preserved. If `mask_image` is a PIL image, it will be converted
                 to a single channel (luminance) before use. If it's a tensor, it should contain one color channel (L)
                 instead of 3, so the expected shape would be `(B, H, W, 1)`.
-            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            height (`int`, *optional*, defaults to 512):
                 The height in pixels of the generated image.
-            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+            width (`int`, *optional*, defaults to 512):
                 The width in pixels of the generated image.
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the

@@ -343,9 +322,6 @@ def __call__(
             list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
             (nsfw) content, according to the `safety_checker`.
         """
-        # 0. Default height and width to unet
-        height = height or self.unet.config.sample_size * self.vae_scale_factor
-        width = width or self.unet.config.sample_size * self.vae_scale_factor

         if isinstance(prompt, str):
             batch_size = 1

@@ -381,12 +357,7 @@ def __call__(
         )

         num_channels_latents = NUM_LATENT_CHANNELS
-        latents_shape = (
-            batch_size * num_images_per_prompt,
-            num_channels_latents,
-            height // self.vae_scale_factor,
-            width // self.vae_scale_factor,
-        )
+        latents_shape = (batch_size * num_images_per_prompt, num_channels_latents, height // 8, width // 8)
         latents_dtype = text_embeddings.dtype
         if latents is None:
             latents = generator.randn(*latents_shape).astype(latents_dtype)
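
Finally, an inpainting sketch under the same assumptions (an inpainting checkpoint converted to ONNX at a placeholder path; per the docstring above, white mask pixels are repainted and black pixels preserved):

# Inpainting sketch; the checkpoint path and image files are placeholders.
from PIL import Image
from diffusers import OnnxStableDiffusionInpaintPipeline

pipe = OnnxStableDiffusionInpaintPipeline.from_pretrained(
    "./sd_inpaint_onnx", provider="CPUExecutionProvider"
)
image = Image.open("photo.png").convert("RGB").resize((512, 512))
mask = Image.open("mask.png").convert("RGB").resize((512, 512))
result = pipe(
    "a vase of sunflowers on the table",
    image=image,
    mask_image=mask,
    height=512,
    width=512,
).images[0]
result.save("inpainted.png")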
