
Commit 178e513

Enabled cpu offload

1 parent 44e3847

2 files changed: +10 -7 lines changed

src/diffusers/loaders/ip_adapter.py (4 additions, 4 deletions)

@@ -378,8 +378,8 @@ def is_ip_adapter_active(self) -> bool:
     def load_ip_adapter(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
-        subfolder: str,
         weight_name: str = "ip-adapter.safetensors",
+        subfolder: Optional[str] = None,
         image_encoder_folder: Optional[str] = "image_encoder",
         **kwargs,
     ) -> None:

@@ -393,12 +393,12 @@ def load_ip_adapter(
                     with [`ModelMixin.save_pretrained`].
                 - A [torch state
                   dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
-            subfolder (`str`):
-                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
-                list is passed, it should have the same length as `weight_name`.
             weight_name (`str`, defaults to "ip-adapter.safetensors"):
                 The name of the weight file to load. If a list is passed, it should have the same length as
                 `subfolder`.
+            subfolder (`str`, *optional*):
+                The subfolder location of a model file within a larger model repository on the Hub or locally. If a
+                list is passed, it should have the same length as `weight_name`.
             image_encoder_folder (`str`, *optional*, defaults to `image_encoder`):
                 The subfolder location of the image encoder within a larger model repository on the Hub or locally.
                 Pass `None` to not load the image encoder. If the image encoder is located in a folder inside
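The `subfolder` argument now defaults to `None` and follows `weight_name`, so code that passed it positionally as the second argument should switch to keyword arguments. A minimal sketch, assuming an SD3 pipeline that mixes in this loader; the repository id and file names below are illustrative, not taken from this commit:

import torch
from diffusers import StableDiffusion3Pipeline

# Illustrative identifiers only; substitute the checkpoint and adapter you actually use.
pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)

# `subfolder` is optional and comes after `weight_name`, so pass it by keyword
# (or omit it entirely when the weights sit at the repository root).
pipe.load_ip_adapter(
    "some-org/sd3-ip-adapter",                # assumed Hub repository id
    weight_name="ip-adapter.safetensors",
    subfolder=None,
    image_encoder_folder="image_encoder",
)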

src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py (6 additions, 3 deletions)

@@ -183,6 +183,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
     """

     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->transformer->vae"
+    _exclude_from_cpu_offload = ["image_encoder"]
     _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]

@@ -694,20 +695,22 @@ def interrupt(self):
         return self._interrupt

     # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image
-    def encode_image(self, image: PipelineImageInput) -> torch.Tensor:
+    def encode_image(self, image: PipelineImageInput, device: torch.device) -> torch.Tensor:
         """Encodes the given image into a feature representation using a pre-trained image encoder.

         Args:
             image (`PipelineImageInput`):
                 Input image to be encoded.
+            device: (`torch.device`):
+                Torch device.

         Returns:
             `torch.Tensor`: The encoded image feature representation.
         """
         if not isinstance(image, torch.Tensor):
             image = self.feature_extractor(image, return_tensors="pt").pixel_values

-        image = image.to(device=self.device, dtype=self.dtype)
+        image = image.to(device=device, dtype=self.dtype)

         return self.image_encoder(image, output_hidden_states=True).hidden_states[-2]

@@ -744,7 +747,7 @@ def prepare_ip_adapter_image_embeds(
             else:
                 single_image_embeds = ip_adapter_image_embeds
         elif ip_adapter_image is not None:
-            single_image_embeds = self.encode_image(ip_adapter_image)
+            single_image_embeds = self.encode_image(ip_adapter_image, device)
             if do_classifier_free_guidance:
                 single_negative_image_embeds = torch.zeros_like(single_image_embeds)
             else:
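Together, excluding `image_encoder` from the offload sequence and threading the execution device into `encode_image` is what makes CPU offload usable with the IP adapter: while offload hooks are active, `self.device` may not reflect the device the denoising actually runs on, so the encoder input now follows the device resolved by `prepare_ip_adapter_image_embeds`. A minimal usage sketch, assuming the pipeline and adapter from the previous example and that the pipeline's `__call__` forwards `ip_adapter_image` to `prepare_ip_adapter_image_embeds`; the image URL is a placeholder:

from diffusers.utils import load_image

# Components are offloaded to CPU between uses; the image encoder stays on the
# accelerator because it is listed in _exclude_from_cpu_offload.
pipe.enable_model_cpu_offload()

reference = load_image("https://example.com/reference.png")  # placeholder URL

image = pipe(
    prompt="a portrait in the style of the reference image",
    ip_adapter_image=reference,
    num_inference_steps=28,
).images[0]
image.save("offloaded_ip_adapter.png")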
