@@ -132,6 +132,7 @@ def __init__(
132132
133133 self .watermark = StableDiffusionXLWatermarker ()
134134
135+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_slicing
135136 def enable_vae_slicing (self ):
136137 r"""
137138 Enable sliced VAE decoding.
@@ -141,13 +142,15 @@ def enable_vae_slicing(self):
141142 """
142143 self .vae .enable_slicing ()
143144
145+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_slicing
144146 def disable_vae_slicing (self ):
145147 r"""
146148 Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
147149 computing decoding in one step.
148150 """
149151 self .vae .disable_slicing () # pass-through: only forwards the request to the VAE; nothing is stored on the pipeline in this method
150152
153+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_vae_tiling
151154 def enable_vae_tiling (self ):
152155 r"""
153156 Enable tiled VAE decoding.
@@ -157,6 +160,7 @@ def enable_vae_tiling(self):
157160 """
158161 self .vae .enable_tiling ()
159162
163+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_vae_tiling
160164 def disable_vae_tiling (self ):
161165 r"""
162166 Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
@@ -217,6 +221,7 @@ def enable_model_cpu_offload(self, gpu_id=0):
217221 self .final_offload_hook = hook
218222
219223 @property
224+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
220225 def _execution_device (self ):
221226 r"""
222227 Returns the device on which the pipeline's models will be executed. After calling
@@ -237,12 +242,14 @@ def _execution_device(self):
237242 def encode_prompt (
238243 self ,
239244 prompt ,
240- device ,
241- num_images_per_prompt ,
242- do_classifier_free_guidance ,
245+ device : Optional [ torch . device ] = None ,
246+ num_images_per_prompt : int = 1 ,
247+ do_classifier_free_guidance : bool = True ,
243248 negative_prompt = None ,
244249 prompt_embeds : Optional [torch .FloatTensor ] = None ,
245250 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
251+ pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
252+ negative_pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
246253 lora_scale : Optional [float ] = None ,
247254 ):
248255 r"""
@@ -268,9 +275,18 @@ def encode_prompt(
268275 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
269276 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
270277 argument.
278+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
279+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
280+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
281+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
282+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
283+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
284+ input argument.
271285 lora_scale (`float`, *optional*):
272286 A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
273287 """
288+ device = device or self ._execution_device
289+
274290 # set lora scale so that monkey patched LoRA
275291 # function of text encoder can correctly access it
276292 if lora_scale is not None and isinstance (self , LoraLoaderMixin ):
@@ -399,6 +415,7 @@ def encode_prompt(
399415
400416 negative_prompt_embeds = torch .concat (negative_prompt_embeds_list , dim = - 1 )
401417
418+ bs_embed = pooled_prompt_embeds .shape [0 ]
402419 pooled_prompt_embeds = pooled_prompt_embeds .repeat (1 , num_images_per_prompt ).view (
403420 bs_embed * num_images_per_prompt , - 1
404421 )
@@ -408,20 +425,7 @@ def encode_prompt(
408425
409426 return prompt_embeds , negative_prompt_embeds , pooled_prompt_embeds , negative_pooled_prompt_embeds
410427
411- def run_safety_checker (self , image , device , dtype ):
412- if self .safety_checker is None :
413- has_nsfw_concept = None
414- else :
415- if torch .is_tensor (image ):
416- feature_extractor_input = self .image_processor .postprocess (image , output_type = "pil" )
417- else :
418- feature_extractor_input = self .image_processor .numpy_to_pil (image )
419- safety_checker_input = self .feature_extractor (feature_extractor_input , return_tensors = "pt" ).to (device )
420- image , has_nsfw_concept = self .safety_checker (
421- images = image , clip_input = safety_checker_input .pixel_values .to (dtype )
422- )
423- return image , has_nsfw_concept
424-
428+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
425429 def prepare_extra_step_kwargs (self , generator , eta ):
426430 # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
427431 # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
@@ -448,6 +452,8 @@ def check_inputs(
448452 negative_prompt = None ,
449453 prompt_embeds = None ,
450454 negative_prompt_embeds = None ,
455+ pooled_prompt_embeds = None ,
456+ negative_pooled_prompt_embeds = None ,
451457 ):
452458 if height % 8 != 0 or width % 8 != 0 :
453459 raise ValueError (f"`height` and `width` have to be divisible by 8 but are { height } and { width } ." )
@@ -486,6 +492,17 @@ def check_inputs(
486492 f" { negative_prompt_embeds .shape } ."
487493 )
488494
495+ if prompt_embeds is not None and pooled_prompt_embeds is None :
496+ raise ValueError (
497+ "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
498+ )
499+
500+ if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None :
501+ raise ValueError (
502+ "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
503+ )
504+
505+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
489506 def prepare_latents (self , batch_size , num_channels_latents , height , width , dtype , device , generator , latents = None ):
490507 shape = (batch_size , num_channels_latents , height // self .vae_scale_factor , width // self .vae_scale_factor )
491508 if isinstance (generator , list ) and len (generator ) != batch_size :
@@ -535,6 +552,8 @@ def __call__(
535552 latents : Optional [torch .FloatTensor ] = None ,
536553 prompt_embeds : Optional [torch .FloatTensor ] = None ,
537554 negative_prompt_embeds : Optional [torch .FloatTensor ] = None ,
555+ pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
556+ negative_pooled_prompt_embeds : Optional [torch .FloatTensor ] = None ,
538557 output_type : Optional [str ] = "pil" ,
539558 return_dict : bool = True ,
540559 callback : Optional [Callable [[int , int , torch .FloatTensor ], None ]] = None ,
@@ -588,6 +607,13 @@ def __call__(
588607 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
589608 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
590609 argument.
610+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
611+ Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
612+ If not provided, pooled text embeddings will be generated from `prompt` input argument.
613+ negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
614+ Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
615+ weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt`
616+ input argument.
591617 output_type (`str`, *optional*, defaults to `"pil"`):
592618 The output format of the generate image. Choose between
593619 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -634,7 +660,15 @@ def __call__(
634660
635661 # 1. Check inputs. Raise error if not correct
636662 self .check_inputs (
637- prompt , height , width , callback_steps , negative_prompt , prompt_embeds , negative_prompt_embeds
663+ prompt ,
664+ height ,
665+ width ,
666+ callback_steps ,
667+ negative_prompt ,
668+ prompt_embeds ,
669+ negative_prompt_embeds ,
670+ pooled_prompt_embeds ,
671+ negative_pooled_prompt_embeds ,
638672 )
639673
640674 # 2. Define call parameters
@@ -669,6 +703,8 @@ def __call__(
669703 negative_prompt ,
670704 prompt_embeds = prompt_embeds ,
671705 negative_prompt_embeds = negative_prompt_embeds ,
706+ pooled_prompt_embeds = pooled_prompt_embeds ,
707+ negative_pooled_prompt_embeds = negative_pooled_prompt_embeds ,
672708 lora_scale = text_encoder_lora_scale ,
673709 )
674710
@@ -765,27 +801,19 @@ def __call__(
765801 latents = latents .float ()
766802
767803 if not output_type == "latent" :
768- # CHECK there is problem here (PVP)
769804 image = self .vae .decode (latents / self .vae .config .scaling_factor , return_dict = False )[0 ]
770- has_nsfw_concept = None
771805 else :
772806 image = latents
773- has_nsfw_concept = None
774- return StableDiffusionXLPipelineOutput (images = image , nsfw_content_detected = None )
775-
776- if has_nsfw_concept is None :
777- do_denormalize = [True ] * image .shape [0 ]
778- else :
779- do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept ]
807+ return StableDiffusionXLPipelineOutput (images = image )
780808
781809 image = self .watermark .apply_watermark (image )
782- image = self .image_processor .postprocess (image , output_type = output_type , do_denormalize = do_denormalize )
810+ image = self .image_processor .postprocess (image , output_type = output_type )
783811
784812 # Offload last model to CPU
785813 if hasattr (self , "final_offload_hook" ) and self .final_offload_hook is not None :
786814 self .final_offload_hook .offload ()
787815
788816 if not return_dict :
789- return (image , has_nsfw_concept )
817+ return (image ,)
790818
791- return StableDiffusionXLPipelineOutput (images = image , nsfw_content_detected = has_nsfw_concept )
819+ return StableDiffusionXLPipelineOutput (images = image )
0 commit comments