import torch
from transformers import (
+    BaseImageProcessor,
    CLIPTextModelWithProjection,
    CLIPTokenizer,
+    PreTrainedModel,
    T5EncoderModel,
    T5TokenizerFast,
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.controlnets.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
from ...models.transformers import SD3Transformer2DModel
@@ -138,7 +140,9 @@ def retrieve_timesteps(
    return timesteps, num_inference_steps


-class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
+class StableDiffusion3ControlNetPipeline(
+    DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin, SD3IPAdapterMixin
+):
    r"""
    Args:
        transformer ([`SD3Transformer2DModel`]):
@@ -174,10 +178,14 @@ class StableDiffusion3ControlNetPipeline(DiffusionPipeline, SD3LoraLoaderMixin,
                Provides additional conditioning to the `unet` during the denoising process. If you set multiple
                ControlNets as a list, the outputs from each ControlNet are added together to create one combined
                additional conditioning.
+        image_encoder (`PreTrainedModel`, *optional*):
+            Pre-trained vision model used to obtain image features for IP-Adapter.
+        feature_extractor (`BaseImageProcessor`, *optional*):
+            Image processor that prepares the input image for the IP-Adapter image encoder.
    """

-    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

    def __init__(
@@ -194,6 +202,8 @@ def __init__(
        controlnet: Union[
            SD3ControlNetModel, List[SD3ControlNetModel], Tuple[SD3ControlNetModel], SD3MultiControlNetModel
        ],
+        image_encoder: PreTrainedModel = None,
+        feature_extractor: BaseImageProcessor = None,
    ):
        super().__init__()
        if isinstance(controlnet, (list, tuple)):
@@ -223,6 +233,8 @@ def __init__(
            transformer=transformer,
            scheduler=scheduler,
            controlnet=controlnet,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
        )
        self.vae_scale_factor = (
            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
@@ -727,6 +739,84 @@ def num_timesteps(self):
    def interrupt(self):
        return self._interrupt

+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.encode_image
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
+        """Encodes the given image into a feature representation using a pre-trained image encoder.
+
+        Args:
+            image (`PipelineImageInput`):
+                Input image to be encoded.
+            device (`torch.device`):
+                Torch device.
+
+        Returns:
+            `torch.Tensor`: The encoded image feature representation.
+        """
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=self.dtype)
+
+        return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        do_classifier_free_guidance: bool = True,
+    ) -> torch.Tensor:
+        """Prepares image embeddings for use in the IP-Adapter.
+
+        Either `ip_adapter_image` or `ip_adapter_image_embeds` must be passed.
+
+        Args:
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                The input image to extract features from for IP-Adapter.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Precomputed image embeddings.
+            device (`torch.device`, *optional*):
+                Torch device.
+            num_images_per_prompt (`int`, defaults to 1):
+                Number of images that should be generated per prompt.
+            do_classifier_free_guidance (`bool`, defaults to True):
+                Whether to use classifier-free guidance or not.
+        """
+        device = device or self._execution_device
+
+        if ip_adapter_image_embeds is not None:
+            if do_classifier_free_guidance:
+                single_negative_image_embeds, single_image_embeds = ip_adapter_image_embeds.chunk(2)
+            else:
+                single_image_embeds = ip_adapter_image_embeds
+        elif ip_adapter_image is not None:
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
+            if do_classifier_free_guidance:
+                single_negative_image_embeds = torch.zeros_like(single_image_embeds)
+        else:
+            raise ValueError("Neither `ip_adapter_image` nor `ip_adapter_image_embeds` were provided.")
+
+        image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+
+        if do_classifier_free_guidance:
+            negative_image_embeds = torch.cat([single_negative_image_embeds] * num_images_per_prompt, dim=0)
+            image_embeds = torch.cat([negative_image_embeds, image_embeds], dim=0)
+
+        return image_embeds.to(device=device)
+
+    # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline.enable_sequential_cpu_offload
+    def enable_sequential_cpu_offload(self, *args, **kwargs):
+        if self.image_encoder is not None and "image_encoder" not in self._exclude_from_cpu_offload:
+            logger.warning(
+                "`pipe.enable_sequential_cpu_offload()` might fail for `image_encoder` if it uses "
+                "`torch.nn.MultiheadAttention`. You can exclude `image_encoder` from CPU offloading by calling "
+                "`pipe._exclude_from_cpu_offload.append('image_encoder')` before `pipe.enable_sequential_cpu_offload()`."
+            )
+
+        super().enable_sequential_cpu_offload(*args, **kwargs)
+
    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
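A note on the layout produced by `prepare_ip_adapter_image_embeds` in the hunk above: with classifier-free guidance enabled, the returned tensor stacks the negative embeddings before the positive ones along the batch dimension, and a precomputed `ip_adapter_image_embeds` argument is expected in the same layout so that `chunk(2)` can split it back apart. A minimal standalone sketch of that layout (`seq_len` and `emb_dim` are placeholder values, not taken from this commit):

import torch

seq_len, emb_dim = 4, 8                          # placeholder dimensions
positive = torch.randn(1, seq_len, emb_dim)      # stands in for encode_image(...) output
negative = torch.zeros_like(positive)            # CFG pairs it with a zero image embedding

num_images_per_prompt = 2
pos = torch.cat([positive] * num_images_per_prompt, dim=0)
neg = torch.cat([negative] * num_images_per_prompt, dim=0)
image_embeds = torch.cat([neg, pos], dim=0)      # negative half first, positive half second

# chunk(2) along dim 0 recovers the two halves, which is how the method
# interprets a user-supplied `ip_adapter_image_embeds` tensor.
neg_half, pos_half = image_embeds.chunk(2)
assert torch.equal(neg_half, neg) and torch.equal(pos_half, pos)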
@@ -754,6 +844,8 @@ def __call__(
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -843,6 +935,12 @@ def __call__(
                Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                weighting. If not provided, pooled `negative_prompt_embeds` will be generated from the
                `negative_prompt` input argument.
+            ip_adapter_image (`PipelineImageInput`, *optional*):
+                Optional image input to work with IP-Adapters.
+            ip_adapter_image_embeds (`torch.Tensor`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. Should be a tensor of shape `(batch_size, num_images,
+                emb_dim)`. It should contain the negative image embedding if `do_classifier_free_guidance` is set to
+                `True`. If not provided, embeddings are computed from the `ip_adapter_image` input argument.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between
                [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -1040,7 +1138,22 @@ def __call__(
            # SD35 official 8b controlnet does not use encoder_hidden_states
            controlnet_encoder_hidden_states = None

-        # 7. Denoising loop
+        # 7. Prepare image embeddings
+        if (ip_adapter_image is not None and self.is_ip_adapter_active) or ip_adapter_image_embeds is not None:
+            ip_adapter_image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+                self.do_classifier_free_guidance,
+            )
+
+            if self.joint_attention_kwargs is None:
+                self._joint_attention_kwargs = {"ip_adapter_image_embeds": ip_adapter_image_embeds}
+            else:
+                self._joint_attention_kwargs.update(ip_adapter_image_embeds=ip_adapter_image_embeds)
+
+        # 8. Denoising loop
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
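For context, a usage sketch of the new arguments. The model IDs, the IP-Adapter checkpoint, and the image paths below are assumptions for illustration, not taken from this commit; `load_ip_adapter` comes from the `SD3IPAdapterMixin` added to the class bases above, and the checkpoint pairing must actually be compatible in practice.

import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Placeholder checkpoints: any mutually compatible SD3 base + SD3 ControlNet
# + SD3 IP-Adapter combination.
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

# Loads the adapter weights; depending on the checkpoint layout this also
# provides the image_encoder / feature_extractor registered as optional
# components in this change, otherwise pass them to from_pretrained yourself.
pipe.load_ip_adapter("InstantX/SD3.5-Large-IP-Adapter")

image = pipe(
    prompt="a photo of a cat",
    control_image=load_image("canny_edges.png"),         # placeholder path
    ip_adapter_image=load_image("style_reference.png"),  # placeholder path
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("output.png")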