
Commit 68bd693

guiyrt and hlky authored

IP-Adapter support for StableDiffusion3ControlNetPipeline (#10363)

* IP-Adapter support for `StableDiffusion3ControlNetPipeline`

* Update src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Co-authored-by: hlky <[email protected]>

---------

Co-authored-by: hlky <[email protected]>

1 parent f4fdb3a commit 68bd693
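With this change, the ControlNet SD3 pipeline can take IP-Adapter image prompts the same way `StableDiffusion3Pipeline` does. A minimal usage sketch: `load_ip_adapter` and `set_ip_adapter_scale` come from `SD3IPAdapterMixin`, which the pipeline now inherits; the repo ids and input files below are illustrative assumptions, not part of this commit.

    import torch
    from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
    from diffusers.utils import load_image

    # Hypothetical checkpoint ids; substitute whichever SD3 ControlNet and
    # IP-Adapter checkpoints you actually use.
    controlnet = SD3ControlNetModel.from_pretrained(
        "InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16
    )
    pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        controlnet=controlnet,
        torch_dtype=torch.float16,
    ).to("cuda")

    # Provided by SD3IPAdapterMixin, newly inherited in this commit.
    pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")
    pipe.set_ip_adapter_scale(0.6)

    image = pipe(
        prompt="a cat sitting on a windowsill",
        control_image=load_image("canny_edges.png"),   # ControlNet conditioning
        ip_adapter_image=load_image("style_ref.png"),  # new argument in this commit
        num_inference_steps=28,
        guidance_scale=7.0,
    ).images[0]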

File tree

3 files changed: +122 −6 lines changed


src/diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py

Lines changed: 118 additions & 5 deletions
@@ -17,14 +17,16 @@
 
 import torch
 from transformers import (
+    BaseImageProcessor,
     CLIPTextModelWithProjection,
     CLIPTokenizer,
+    PreTrainedModel,
     T5EncoderModel,
     T5TokenizerFast,
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
 from ...models.transformers import SD3Transformer2DModel
@@ -138,7 +140,9 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
             Provides additional conditioning to the `unet` during the denoising process. If you set multiple
             ControlNets as a list, the outputs from each ControlNet are added together to create one combined
             additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained Vision Model for IP Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor for IP Adapter.
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
 
     def __init__(
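Because `image_encoder` and `feature_extractor` are listed in `_optional_components`, `from_pretrained` keeps working for checkpoints that ship without them, and they can also be supplied explicitly. A sketch of the explicit form, assuming a SigLIP encoder (the concrete encoder class and repo id are assumptions; any `PreTrainedModel` exposing `hidden_states` fits, since `encode_image` reads `hidden_states[-2]`):

    from transformers import SiglipImageProcessor, SiglipVisionModel

    # Hypothetical encoder choice for illustration.
    image_encoder = SiglipVisionModel.from_pretrained("google/siglip-so400m-patch14-384")
    feature_extractor = SiglipImageProcessor.from_pretrained("google/siglip-so400m-patch14-384")

    pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        controlnet=controlnet,  # as in the earlier sketch
        image_encoder=image_encoder,
        feature_extractor=feature_extractor,
    )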
@@ -194,6 +202,8 @@ def __init__(
         controlnet: Union[
             SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
         ],
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
     ):
         super().__init__()
         if isinstance(controlnet, (list, tuple)):
@@ -223,6 +233,8 @@ def __init__(
             transformer=transformer,
             scheduler=scheduler,
             controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
             2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -727,6 +739,84 @@ def num_timesteps(self):
     def interrupt(self):
         return self._interrupt
 
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
     @torch.no_grad()
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
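When classifier-free guidance is on, `prepare_ip_adapter_image_embeds` stacks a zeroed negative embedding in front of the positive one, so downstream code (and callers passing precomputed embeddings back in) can split the tensor with `chunk(2)`. A standalone sketch of that layout; the shapes are illustrative only:

    import torch

    # Stand-ins for one encoded image: (1, num_tokens, emb_dim).
    single_image_embeds = torch.randn(1, 729, 1152)
    single_negative_image_embeds = torch.zeros_like(single_image_embeds)

    num_images_per_prompt = 2
    positive = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
    negative = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)

    # Negative half first, positive half second -- the order `chunk(2)` expects.
    image_embeds = torch.cat([negative, positive], dim=0)
    assert image_embeds.shape == (2 * num_images_per_prompt, 729, 1152)

    neg, pos = image_embeds.chunk(2)  # round-trips back to the two halves
    assert torch.equal(neg, negative) and torch.equal(pos, positive)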
@@ -754,6 +844,8 @@ def __call__(
         negative_prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -843,6 +935,12 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1040,7 +1138,22 @@
             # SD35 official 8b controlnet does not use encoder_hidden_states
             controlnet_encoder_hidden_states = None
 
-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
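The embeddings reach the attention processors through `joint_attention_kwargs`, so callers who precompute them can skip image encoding on repeat generations. A sketch of that reuse pattern; `pipe` is the pipeline from the earlier sketch with an adapter loaded, and `style_image`/`control_image` are placeholder inputs:

    # Encode once; with CFG on, the result already contains the zeroed negative half.
    embeds = pipe.prepare_ip_adapter_image_embeds(
        ip_adapter_image=style_image,
        ip_adapter_image_embeds=None,
        device=pipe._execution_device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=True,
    )

    # Reuse across prompts without re-running the image encoder each time.
    for prompt in ["a cat", "a dog", "a fox"]:
        image = pipe(
            prompt=prompt,
            control_image=control_image,
            ip_adapter_image_embeds=embeds,
        ).images[0]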

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py

Lines changed: 2 additions & 1 deletion
@@ -870,7 +870,8 @@ def __call__(
                 Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
-            ip_adapter_image (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP Adapters.
             ip_adapter_image_embeds (`torch.Tensor`, *optional*):
                 Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
                 emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to

tests/pipelines/controlnet_sd3/test_controlnet_sd3.py

Lines changed: 2 additions & 0 deletions
@@ -150,6 +150,8 @@ def get_dummy_components(
             "transformer": transformer,
             "vae": vae,
             "controlnet": controlnet,
+            "image_encoder": None,
+            "feature_extractor": None,
         }
 
     def get_dummy_inputs(self, device, seed=0):
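Registering the two new components as `None` is deliberate: both names are in `_optional_components`, so the pipeline must stay constructible and runnable without an image encoder. A sketch of that invariant in the test's own style (the assertion is an illustrative assumption, not part of this commit):

    components = self.get_dummy_components()
    sd_pipe = StableDiffusion3ControlNetPipeline(**components)

    # Valid: IP-Adapter code paths stay dormant until an adapter is loaded.
    assert sd_pipe.image_encoder is None and sd_pipe.feature_extractor is None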
