Commit 7c4b77c

StableDiffusionXLControlNetUnionImg2ImgPipeline example
1 parent a0f2874 commit 7c4b77c

File tree: 1 file changed (+85 additions, -60 deletions)

src/diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py

Lines changed: 85 additions & 60 deletions
@@ -68,74 +68,99 @@
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
-        >>> # pip install accelerate transformers safetensors diffusers
-
+        # !pip install controlnet_aux
+        >>> from diffusers import (
+        ...     StableDiffusionXLControlNetUnionImg2ImgPipeline,
+        ...     ControlNetUnionModel,
+        ...     AutoencoderKL,
+        ... )
+        >>> from diffusers.models.controlnets import ControlNetUnionInputProMax
+        >>> from diffusers.utils import load_image
         >>> import torch
-        >>> import numpy as np
         >>> from PIL import Image
-
-        >>> from transformers import DPTImageProcessor, DPTForDepthEstimation
-        >>> from diffusers import ControlNetModel, StableDiffusionXLControlNetImg2ImgPipeline, AutoencoderKL
-        >>> from diffusers.utils import load_image
-
-
-        >>> depth_estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
-        >>> feature_extractor = DPTImageProcessor.from_pretrained("Intel/dpt-hybrid-midas")
-        >>> controlnet = ControlNetModel.from_pretrained(
-        ...     "diffusers/controlnet-depth-sdxl-1.0-small",
-        ...     variant="fp16",
-        ...     use_safetensors=True,
-        ...     torch_dtype=torch.float16,
+        >>> import numpy as np
+        >>> prompt = "A cat"
+        >>> # download an image
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/kandinsky/cat.png"
         ... )
-        >>> vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16)
-        >>> pipe = StableDiffusionXLControlNetImg2ImgPipeline.from_pretrained(
+        >>> # initialize the models and pipeline
+        >>> controlnet = ControlNetUnionModel.from_pretrained(
+        ...     "brad-twinkl/controlnet-union-sdxl-1.0-promax", torch_dtype=torch.float16
+        ... )
+        >>> vae = AutoencoderKL.from_pretrained(
+        ...     "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
+        ... )
+        >>> pipe = StableDiffusionXLControlNetUnionImg2ImgPipeline.from_pretrained(
         ...     "stabilityai/stable-diffusion-xl-base-1.0",
         ...     controlnet=controlnet,
         ...     vae=vae,
-        ...     variant="fp16",
-        ...     use_safetensors=True,
         ...     torch_dtype=torch.float16,
-        ... )
-        >>> pipe.enable_model_cpu_offload()
-
-
-        >>> def get_depth_map(image):
-        ...     image = feature_extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
-        ...     with torch.no_grad(), torch.autocast("cuda"):
-        ...         depth_map = depth_estimator(image).predicted_depth
-
-        ...     depth_map = torch.nn.functional.interpolate(
-        ...         depth_map.unsqueeze(1),
-        ...         size=(1024, 1024),
-        ...         mode="bicubic",
-        ...         align_corners=False,
+        ... ).to("cuda")
+        >>> # `enable_model_cpu_offload` is not recommended due to multiple generations
+        >>> height = image.height
+        >>> width = image.width
+        >>> ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
+        >>> # 3 * 3 upscale correspond to 16 * 3 multiply, 2 * 2 correspond to 16 * 2 multiply and so on.
+        >>> scale_image_factor = 3
+        >>> base_factor = 16
+        >>> factor = scale_image_factor * base_factor
+        >>> W, H = int(width * ratio) // factor * factor, int(height * ratio) // factor * factor
+        >>> image = image.resize((W, H))
+        >>> target_width = W // scale_image_factor
+        >>> target_height = H // scale_image_factor
+        >>> images = []
+        >>> crops_coords_list = [
+        ...     (0, 0),
+        ...     (0, width // 2),
+        ...     (height // 2, 0),
+        ...     (width // 2, height // 2),
+        ...     0,
+        ...     0,
+        ...     0,
+        ...     0,
+        ...     0,
+        ... ]
+        >>> for i in range(scale_image_factor):
+        ...     for j in range(scale_image_factor):
+        ...         left = j * target_width
+        ...         top = i * target_height
+        ...         right = left + target_width
+        ...         bottom = top + target_height
+        ...         cropped_image = image.crop((left, top, right, bottom))
+        ...         cropped_image = cropped_image.resize((W, H))
+        ...         images.append(cropped_image)
+        >>> # set ControlNetUnion input
+        >>> result_images = []
+        >>> for sub_img, crops_coords in zip(images, crops_coords_list):
+        ...     union_input = ControlNetUnionInputProMax(
+        ...         tile=sub_img,
         ...     )
-        ...     depth_min = torch.amin(depth_map, dim=[1, 2, 3], keepdim=True)
-        ...     depth_max = torch.amax(depth_map, dim=[1, 2, 3], keepdim=True)
-        ...     depth_map = (depth_map - depth_min) / (depth_max - depth_min)
-        ...     image = torch.cat([depth_map] * 3, dim=1)
-        ...     image = image.permute(0, 2, 3, 1).cpu().numpy()[0]
-        ...     image = Image.fromarray((image * 255.0).clip(0, 255).astype(np.uint8))
-        ...     return image
-
-
-        >>> prompt = "A robot, 4k photo"
-        >>> image = load_image(
-        ...     "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main"
-        ...     "/kandinsky/cat.png"
-        ... ).resize((1024, 1024))
-        >>> controlnet_conditioning_scale = 0.5  # recommended for good generalization
-        >>> depth_image = get_depth_map(image)
-
-        >>> images = pipe(
-        ...     prompt,
-        ...     image=image,
-        ...     control_image=depth_image,
-        ...     strength=0.99,
-        ...     num_inference_steps=50,
-        ...     controlnet_conditioning_scale=controlnet_conditioning_scale,
-        ... ).images
-        >>> images[0].save(f"robot_cat.png")
+        ...     new_width, new_height = W, H
+        ...     out = pipe(
+        ...         prompt=[prompt] * 1,
+        ...         image=sub_img,
+        ...         control_image_list=union_input,
+        ...         width=new_width,
+        ...         height=new_height,
+        ...         num_inference_steps=30,
+        ...         crops_coords_top_left=(W, H),
+        ...         target_size=(W, H),
+        ...         original_size=(W * 2, H * 2),
+        ...     )
+        ...     result_images.append(out.images[0])
+        >>> new_im = Image.new(
+        ...     "RGB", (new_width * scale_image_factor, new_height * scale_image_factor)
+        ... )
+        >>> new_im.paste(result_images[0], (0, 0))
+        >>> new_im.paste(result_images[1], (new_width, 0))
+        >>> new_im.paste(result_images[2], (new_width * 2, 0))
+        >>> new_im.paste(result_images[3], (0, new_height))
+        >>> new_im.paste(result_images[4], (new_width, new_height))
+        >>> new_im.paste(result_images[5], (new_width * 2, new_height))
+        >>> new_im.paste(result_images[6], (0, new_height * 2))
+        >>> new_im.paste(result_images[7], (new_width, new_height * 2))
+        >>> new_im.paste(result_images[8], (new_width * 2, new_height * 2))
         ```
 """
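For readers following the new tile-upscale example, here is a minimal standalone sketch of its resizing arithmetic only, assuming a hypothetical 768x512 input (the real cat.png dimensions may differ). It is not part of the commit; it just shows how the W, H, and per-tile target sizes in the docstring come out.

```py
import numpy as np

# Hypothetical input size, for illustration only; the example image's real
# dimensions may differ.
width, height = 768, 512

# The example rescales the input so it covers roughly 1024 * 1024 pixels, then
# snaps the result down to a multiple of scale_image_factor * 16 so that every
# tile (and the stitched 3x3 output) stays divisible by the SDXL latent factor.
scale_image_factor = 3   # 3 x 3 grid of tiles
base_factor = 16
factor = scale_image_factor * base_factor  # 48

ratio = np.sqrt(1024.0 * 1024.0 / (width * height))
W = int(width * ratio) // factor * factor
H = int(height * ratio) // factor * factor
target_width = W // scale_image_factor    # size of each crop before the
target_height = H // scale_image_factor   # pipeline upscales it back to (W, H)

print(W, H, target_width, target_height)  # 1248 816 416 272 for 768x512
```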
