Skip to content

Commit e973de6

Browse files
committed
fix controlnet inpaint preprocess
1 parent db94ca8 commit e973de6

File tree

1 file changed

+12
-3
lines changed

1 file changed

+12
-3
lines changed

src/diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_modular.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1299,6 +1299,7 @@ def intermediates_inputs(self) -> List[str]:
12991299
"masked_image_latents",
13001300
"noise",
13011301
"image_latents",
1302+
"crops_coords",
13021303
]
13031304

13041305
@property
@@ -1350,6 +1351,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
13501351
masked_image_latents = state.get_intermediate("masked_image_latents")
13511352
noise = state.get_intermediate("noise")
13521353
image_latents = state.get_intermediate("image_latents")
1354+
crops_coords = state.get_intermediate("crops_coords")
13531355
num_channels_unet = pipeline.unet.config.in_channels
13541356
if num_channels_unet == 9:
13551357
# default case for runwayml/stable-diffusion-inpainting
@@ -1409,6 +1411,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
14091411
num_images_per_prompt=num_images_per_prompt,
14101412
device=device,
14111413
dtype=controlnet.dtype,
1414+
crops_coords=crops_coords,
14121415
)
14131416
elif isinstance(controlnet, MultiControlNetModel):
14141417
control_images = []
@@ -1422,6 +1425,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
14221425
num_images_per_prompt=num_images_per_prompt,
14231426
device=device,
14241427
dtype=controlnet.dtype,
1428+
crops_coords=crops_coords,
14251429
)
14261430

14271431
control_images.append(control_image)
@@ -1947,7 +1951,8 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
19471951
return image_embeds, uncond_image_embeds
19481952

19491953
# Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
1950-
# return image without apply any guidance
1954+
# 1. return image without applying any guidance
1955+
# 2. add crops_coords and resize_mode to preprocess()
19511956
def prepare_control_image(
19521957
self,
19531958
image,
@@ -1957,8 +1962,12 @@ def prepare_control_image(
19571962
num_images_per_prompt,
19581963
device,
19591964
dtype,
1960-
):
1961-
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
1965+
crops_coords=None,
1966+
):
1967+
if crops_coords is not None:
1968+
image = self.control_image_processor.preprocess(image, height=height, width=width, crops_coords=crops_coords, resize_mode="fill").to(dtype=torch.float32)
1969+
else:
1970+
image = self.control_image_processor.preprocess(image, height=height, width=width).to(dtype=torch.float32)
19621971
image_batch_size = image.shape[0]
19631972

19641973
if image_batch_size == 1:

0 commit comments

Comments
 (0)