@@ -1299,6 +1299,7 @@ def intermediates_inputs(self) -> List[str]:
12991299 "masked_image_latents" ,
13001300 "noise" ,
13011301 "image_latents" ,
1302+ "crops_coords" ,
13021303 ]
13031304
13041305 @property
@@ -1350,6 +1351,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
13501351 masked_image_latents = state .get_intermediate ("masked_image_latents" )
13511352 noise = state .get_intermediate ("noise" )
13521353 image_latents = state .get_intermediate ("image_latents" )
1354+ crops_coords = state .get_intermediate ("crops_coords" )
13531355 num_channels_unet = pipeline .unet .config .in_channels
13541356 if num_channels_unet == 9 :
13551357 # default case for runwayml/stable-diffusion-inpainting
@@ -1409,6 +1411,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
14091411 num_images_per_prompt = num_images_per_prompt ,
14101412 device = device ,
14111413 dtype = controlnet .dtype ,
1414+ crops_coords = crops_coords ,
14121415 )
14131416 elif isinstance (controlnet , MultiControlNetModel ):
14141417 control_images = []
@@ -1422,6 +1425,7 @@ def __call__(self, pipeline, state: PipelineState) -> PipelineState:
14221425 num_images_per_prompt = num_images_per_prompt ,
14231426 device = device ,
14241427 dtype = controlnet .dtype ,
1428+ crops_coords = crops_coords ,
14251429 )
14261430
14271431 control_images .append (control_image )
@@ -1947,7 +1951,8 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
19471951 return image_embeds , uncond_image_embeds
19481952
19491953 # Modified from diffusers.pipelines.controlnet.pipeline_controlnet_sd_xl.StableDiffusionXLControlNetPipeline.prepare_image
1950- # return image without apply any guidance
1954+ # 1. return the image without applying any guidance
1955+ # 2. when crops_coords is given, pass crops_coords and resize_mode="fill" to preprocess()
19511956 def prepare_control_image (
19521957 self ,
19531958 image ,
@@ -1957,8 +1962,12 @@ def prepare_control_image(
19571962 num_images_per_prompt ,
19581963 device ,
19591964 dtype ,
1960- ):
1961- image = self .control_image_processor .preprocess (image , height = height , width = width ).to (dtype = torch .float32 )
1965+ crops_coords = None ,
1966+ ):
1967+ if crops_coords is not None :
1968+ image = self .control_image_processor .preprocess (image , height = height , width = width , crops_coords = crops_coords , resize_mode = "fill" ).to (dtype = torch .float32 )
1969+ else :
1970+ image = self .control_image_processor .preprocess (image , height = height , width = width ).to (dtype = torch .float32 )
19621971 image_batch_size = image .shape [0 ]
19631972
19641973 if image_batch_size == 1 :
0 commit comments