Skip to content

Commit e7eadad

Browse files
committed
StableDiffusionXLControlNetUnionInpaintPipeline example
1 parent eb0524d commit e7eadad

File tree

1 file changed

+40
-50
lines changed

1 file changed

+40
-50
lines changed

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py

Lines changed: 40 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -81,56 +81,46 @@ def retrieve_latents(
8181
EXAMPLE_DOC_STRING = """
8282
Examples:
8383
```py
84-
>>> # !pip install transformers accelerate
85-
>>> from diffusers import StableDiffusionXLControlNetInpaintPipeline, ControlNetModel, DDIMScheduler
86-
>>> from diffusers.utils import load_image
87-
>>> from PIL import Image
88-
>>> import numpy as np
89-
>>> import torch
90-
91-
>>> init_image = load_image(
92-
... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy.png"
93-
... )
94-
>>> init_image = init_image.resize((1024, 1024))
95-
96-
>>> generator = torch.Generator(device="cpu").manual_seed(1)
97-
98-
>>> mask_image = load_image(
99-
... "https://huggingface.co/datasets/diffusers/test-arrays/resolve/main/stable_diffusion_inpaint/boy_mask.png"
100-
... )
101-
>>> mask_image = mask_image.resize((1024, 1024))
102-
103-
104-
>>> def make_canny_condition(image):
105-
... image = np.array(image)
106-
... image = cv2.Canny(image, 100, 200)
107-
... image = image[:, :, None]
108-
... image = np.concatenate([image, image, image], axis=2)
109-
... image = Image.fromarray(image)
110-
... return image
111-
112-
113-
>>> control_image = make_canny_condition(init_image)
114-
115-
>>> controlnet = ControlNetModel.from_pretrained(
116-
... "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.float16
117-
... )
118-
>>> pipe = StableDiffusionXLControlNetInpaintPipeline.from_pretrained(
119-
... "stabilityai/stable-diffusion-xl-base-1.0", controlnet=controlnet, torch_dtype=torch.float16
120-
... )
121-
122-
>>> pipe.enable_model_cpu_offload()
123-
124-
>>> # generate image
125-
>>> image = pipe(
126-
... "a handsome man with ray-ban sunglasses",
127-
... num_inference_steps=20,
128-
... generator=generator,
129-
... eta=1.0,
130-
... image=init_image,
131-
... mask_image=mask_image,
132-
... control_image=control_image,
133-
... ).images[0]
84+
from diffusers import StableDiffusionXLControlNetUnionInpaintPipeline, ControlNetUnionModel, AutoencoderKL
85+
from diffusers.models.controlnets import ControlNetUnionInputProMax
86+
from diffusers.utils import load_image
87+
import torch
88+
import numpy as np
89+
from PIL import Image
90+
prompt = "A cat"
91+
# download an image
92+
image = load_image(
93+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo.png"
94+
).resize((1024, 1024))
95+
mask = load_image(
96+
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/in_paint/overture-creations-5sI6fQgYIuo_mask.png"
97+
).resize((1024, 1024))
98+
# initialize the models and pipeline
99+
controlnet = ControlNetUnionModel.from_pretrained(
100+
"brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
101+
)
102+
vae = AutoencoderKL.from_pretrained(
103+
"madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
104+
)
105+
pipe = StableDiffusionXLControlNetUnionInpaintPipeline.from_pretrained(
106+
"stabilityai/stable-diffusion-xl-base-1.0",
107+
controlnet=controlnet,
108+
vae=vae,
109+
torch_dtype=torch.float16,
110+
variant="fp16",
111+
)
112+
pipe.enable_model_cpu_offload()
113+
controlnet_img = image.copy()
114+
controlnet_img_np = np.array(controlnet_img)
115+
mask_np = np.array(mask)
116+
controlnet_img_np[mask_np > 0] = 0
117+
controlnet_img = Image.fromarray(controlnet_img_np)
118+
union_input = ControlNetUnionInputProMax(
119+
repaint=controlnet_img,
120+
)
121+
# generate image
122+
image = pipe(prompt, image=image, mask_image=mask, control_image_list=union_input).images[0]
123+
image.save("inpaint.png")
134124
```
135125
"""
136126

0 commit comments

Comments
 (0)