@@ -183,6 +183,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
183183    """ 
184184
185185    model_cpu_offload_seq  =  "text_encoder->text_encoder_2->text_encoder_3->transformer->vae" 
186+     _exclude_from_cpu_offload  =  ["image_encoder" ]
186187    _optional_components  =  ["image_encoder" , "feature_extractor" ]
187188    _callback_tensor_inputs  =  ["latents" , "prompt_embeds" , "negative_prompt_embeds" , "negative_pooled_prompt_embeds" ]
188189
@@ -694,20 +695,22 @@ def interrupt(self):
694695        return  self ._interrupt 
695696
696697    # Adapted from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.encode_image 
697-     def  encode_image (self , image : PipelineImageInput ) ->  torch .Tensor :
698+     def  encode_image (self , image : PipelineImageInput ,  device :  torch . device ) ->  torch .Tensor :
698699        """Encodes the given image into a feature representation using a pre-trained image encoder. 
699700
700701        Args: 
701702            image (`PipelineImageInput`): 
702703                Input image to be encoded. 
704+             device (`torch.device`): 
705+                 Device the image tensor is moved to before it is passed to the image encoder. 
703706
704707        Returns: 
705708            `torch.Tensor`: The encoded image feature representation. 
706709        """ 
707710        if  not  isinstance (image , torch .Tensor ):
708711            image  =  self .feature_extractor (image , return_tensors = "pt" ).pixel_values 
709712
710-         image  =  image .to (device = self . device , dtype = self .dtype )
713+         image  =  image .to (device = device , dtype = self .dtype )
711714
712715        return  self .image_encoder (image , output_hidden_states = True ).hidden_states [- 2 ]
713716
@@ -744,7 +747,7 @@ def prepare_ip_adapter_image_embeds(
744747            else :
745748                single_image_embeds  =  ip_adapter_image_embeds 
746749        elif  ip_adapter_image  is  not None :
747-             single_image_embeds  =  self .encode_image (ip_adapter_image )
750+             single_image_embeds  =  self .encode_image (ip_adapter_image ,  device )
748751            if  do_classifier_free_guidance :
749752                single_negative_image_embeds  =  torch .zeros_like (single_image_embeds )
750753        else :
0 commit comments