12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
| 15 | +import inspect |
15 | 16 | from typing import Any, Callable, Dict, List, Optional, Union |
16 | 17 |
17 | 18 | import numpy as np |
28 | 29 | from ...utils.torch_utils import randn_tensor |
29 | 30 | from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin |
30 | 31 | from . import StableDiffusionPipelineOutput |
31 | | -from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents |
32 | 32 | from .safety_checker import StableDiffusionSafetyChecker |
33 | 33 |
34 | 34 |
@@ -66,10 +66,23 @@ def preprocess(image): |
66 | 66 | return image |
67 | 67 |
68 | 68 |
| 69 | +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents |
| 70 | +def retrieve_latents( |
| 71 | + encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample" |
| 72 | +): |
| 73 | + if hasattr(encoder_output, "latent_dist") and sample_mode == "sample": |
| 74 | + return encoder_output.latent_dist.sample(generator) |
| 75 | + elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax": |
| 76 | + return encoder_output.latent_dist.mode() |
| 77 | + elif hasattr(encoder_output, "latents"): |
| 78 | + return encoder_output.latents |
| 79 | + else: |
| 80 | + raise AttributeError("Could not access latents of provided encoder_output") |
| 81 | + |
| 82 | + |
69 | 83 | class StableDiffusionInstructPix2PixPipeline( |
70 | 84 | DiffusionPipeline, |
71 | 85 | StableDiffusionMixin, |
72 | | - SDMixin, |
73 | 86 | TextualInversionLoaderMixin, |
74 | 87 | StableDiffusionLoraLoaderMixin, |
75 | 88 | IPAdapterMixin, |
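
The `retrieve_latents` helper re-added above normalizes the two shapes of VAE encoder output: distribution-style outputs (those with a `latent_dist`) can be sampled or reduced to their mode, while outputs that already carry plain `latents` are passed through. Below is a minimal sketch of how it is typically called; the default-config `AutoencoderKL` and the random tensor are illustrative stand-ins for the pipeline's pretrained VAE and a real preprocessed image, and `retrieve_latents` is assumed to be in scope from this module:

```python
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL()  # stand-in; the pipeline uses its pretrained VAE
image = torch.randn(1, 3, 64, 64)  # dummy preprocessed image in [-1, 1]

with torch.no_grad():
    encoder_output = vae.encode(image)  # has a .latent_dist (diagonal Gaussian)

generator = torch.Generator().manual_seed(0)
# default "sample" mode draws from the posterior; "argmax" takes its mode (deterministic)
latents = retrieve_latents(encoder_output, generator=generator)
mode_latents = retrieve_latents(encoder_output, sample_mode="argmax")
```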
@@ -710,6 +723,51 @@ def prepare_ip_adapter_image_embeds( |
710 | 723 |
711 | 724 | return image_embeds |
712 | 725 |
| 726 | + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker |
| 727 | + def run_safety_checker(self, image, device, dtype): |
| 728 | + if self.safety_checker is None: |
| 729 | + has_nsfw_concept = None |
| 730 | + else: |
| 731 | + if torch.is_tensor(image): |
| 732 | + feature_extractor_input = self.image_processor.postprocess(image, output_type="pil") |
| 733 | + else: |
| 734 | + feature_extractor_input = self.image_processor.numpy_to_pil(image) |
| 735 | + safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device) |
| 736 | + image, has_nsfw_concept = self.safety_checker( |
| 737 | + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) |
| 738 | + ) |
| 739 | + return image, has_nsfw_concept |
| 740 | + |
| 741 | + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs |
| 742 | + def prepare_extra_step_kwargs(self, generator, eta): |
| 743 | + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature |
| 744 | + # eta (η) is only used with the DDIMScheduler; it will be ignored by other schedulers. |
| 745 | + # eta corresponds to η in the DDIM paper: https://huggingface.co/papers/2010.02502 |
| 746 | + # and should be in [0, 1] |
| 747 | + |
| 748 | + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
| 749 | + extra_step_kwargs = {} |
| 750 | + if accepts_eta: |
| 751 | + extra_step_kwargs["eta"] = eta |
| 752 | + |
| 753 | + # check if the scheduler accepts generator |
| 754 | + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) |
| 755 | + if accepts_generator: |
| 756 | + extra_step_kwargs["generator"] = generator |
| 757 | + return extra_step_kwargs |
| 758 | + |
| 759 | + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents |
| 760 | + def decode_latents(self, latents): |
| 761 | + deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead" |
| 762 | + deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False) |
| 763 | + |
| 764 | + latents = 1 / self.vae.config.scaling_factor * latents |
| 765 | + image = self.vae.decode(latents, return_dict=False)[0] |
| 766 | + image = (image / 2 + 0.5).clamp(0, 1) |
| 767 | + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 |
| 768 | + image = image.cpu().permute(0, 2, 3, 1).float().numpy() |
| 769 | + return image |
| 770 | + |
713 | 771 | def check_inputs( |
714 | 772 | self, |
715 | 773 | prompt, |
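
A note on the `prepare_extra_step_kwargs` copy above: it feature-detects the scheduler's `step` signature so that `eta` and `generator` are only forwarded to schedulers that accept them. For instance, `DDIMScheduler.step` takes both, while `EulerDiscreteScheduler.step` takes `generator` but not `eta`. The deprecated `decode_latents` copy likewise preserves the legacy post-processing: latents are rescaled by `1 / vae.config.scaling_factor` before decoding, and the decoded image is mapped from [-1, 1] to [0, 1]. A small standalone sketch of the same `inspect.signature` pattern:

```python
import inspect

from diffusers import DDIMScheduler, EulerDiscreteScheduler

for scheduler_cls in (DDIMScheduler, EulerDiscreteScheduler):
    params = set(inspect.signature(scheduler_cls.step).parameters.keys())
    print(
        scheduler_cls.__name__,
        "accepts eta:", "eta" in params,
        "accepts generator:", "generator" in params,
    )
# Assembling the kwargs conditionally avoids a TypeError from schedulers
# whose step() does not take these arguments.
```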
@@ -839,6 +897,21 @@ def prepare_image_latents( |
839 | 897 |
840 | 898 | return image_latents |
841 | 899 |
| 900 | + @property |
| 901 | + def guidance_scale(self): |
| 902 | + return self._guidance_scale |
| 903 | + |
842 | 904 | @property |
843 | 905 | def image_guidance_scale(self): |
844 | 906 | return self._image_guidance_scale |
| 907 | + |
| 908 | + @property |
| 909 | + def num_timesteps(self): |
| 910 | + return self._num_timesteps |
| 911 | + |
| 912 | + # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2) |
| 913 | + # of the Imagen paper: https://huggingface.co/papers/2205.11487. `guidance_scale = 1` |
| 914 | + # corresponds to doing no classifier-free guidance. |
| 915 | + @property |
| 916 | + def do_classifier_free_guidance(self): |
| 917 | + return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0 |
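
Unlike the single-scale check in the base pipeline, `do_classifier_free_guidance` here requires both scales to be active because InstructPix2Pix guides on two conditionings at once: the UNet runs on a 3x batch (text+image conditioned, image-only conditioned, unconditional) and the three predictions are combined with both weights, per equation (3) of the InstructPix2Pix paper. A sketch of that combination as performed in the pipeline's denoising loop, with dummy tensor shapes assumed:

```python
import torch

def combine_noise_preds(noise_pred, guidance_scale=7.5, image_guidance_scale=1.5):
    # noise_pred stacks the 3x batch along dim 0:
    # [text+image cond, image cond only, unconditional]
    noise_pred_text, noise_pred_image, noise_pred_uncond = noise_pred.chunk(3)
    return (
        noise_pred_uncond
        + guidance_scale * (noise_pred_text - noise_pred_image)
        + image_guidance_scale * (noise_pred_image - noise_pred_uncond)
    )

combined = combine_noise_preds(torch.randn(3, 4, 64, 64))  # dummy UNet output
```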