Skip to content

Commit 6ea8360

Browse files
[Stable Diffusion Inpainting] Deprecate inpainting pipeline in favor of official one (#903)
* finish * up * Apply suggestions from code review Co-authored-by: Anton Lozhkov <[email protected]> * Update src/diffusers/pipeline_utils.py * Finish Co-authored-by: Anton Lozhkov <[email protected]>
1 parent bd21607 commit 6ea8360

File tree

6 files changed

+432
-2
lines changed

6 files changed

+432
-2
lines changed

src/diffusers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
LDMTextToImagePipeline,
5353
StableDiffusionImg2ImgPipeline,
5454
StableDiffusionInpaintPipeline,
55+
StableDiffusionInpaintPipelineLegacy,
5556
StableDiffusionPipeline,
5657
)
5758
else:

src/diffusers/pipeline_utils.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,7 @@
4040
ONNX_WEIGHTS_NAME,
4141
WEIGHTS_NAME,
4242
BaseOutput,
43+
deprecate,
4344
is_transformers_available,
4445
logging,
4546
)
@@ -413,6 +414,25 @@ def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.P
413414
diffusers_module = importlib.import_module(cls.__module__.split(".")[0])
414415
pipeline_class = getattr(diffusers_module, config_dict["_class_name"])
415416

417+
# To be removed in 1.0.0
418+
if pipeline_class.__name__ == "StableDiffusionInpaintPipeline" and version.parse(
419+
version.parse(config_dict["_diffusers_version"]).base_version
420+
) <= version.parse("0.5.1"):
421+
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionInpaintPipelineLegacy
422+
423+
pipeline_class = StableDiffusionInpaintPipelineLegacy
424+
425+
deprecation_message = (
426+
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
427+
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
428+
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
429+
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
430+
f" checkpoint {pretrained_model_name_or_path} to the format of"
431+
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
432+
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
433+
)
434+
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
435+
416436
# some modules can be passed directly to the init
417437
# in this case they are already instantiated in `kwargs`
418438
# extract them here

src/diffusers/pipelines/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from .stable_diffusion import (
1717
StableDiffusionImg2ImgPipeline,
1818
StableDiffusionInpaintPipeline,
19+
StableDiffusionInpaintPipelineLegacy,
1920
StableDiffusionPipeline,
2021
)
2122

src/diffusers/pipelines/stable_diffusion/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class StableDiffusionPipelineOutput(BaseOutput):
3131
from .pipeline_stable_diffusion import StableDiffusionPipeline
3232
from .pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
3333
from .pipeline_stable_diffusion_inpaint import StableDiffusionInpaintPipeline
34+
from .pipeline_stable_diffusion_inpaint_legacy import StableDiffusionInpaintPipelineLegacy
3435
from .safety_checker import StableDiffusionSafetyChecker
3536

3637
if is_transformers_available() and is_onnx_available():

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -82,8 +82,6 @@ def __init__(
8282
feature_extractor: CLIPFeatureExtractor,
8383
):
8484
super().__init__()
85-
logger.info("`StableDiffusionInpaintPipeline` is experimental and will very likely change in the future.")
86-
8785
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
8886
deprecation_message = (
8987
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
@@ -223,6 +221,8 @@ def __call__(
223221
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
224222
(nsfw) content, according to the `safety_checker`.
225223
"""
224+
# TODO(Suraj) - adapt to your use case
225+
226226
if isinstance(prompt, str):
227227
batch_size = 1
228228
elif isinstance(prompt, list):

0 commit comments

Comments
 (0)