Skip to content

Commit 51fbe6a

Browse files
committed
up
1 parent d456b5d commit 51fbe6a

File tree

2 files changed

+122
-4
lines changed

2 files changed

+122
-4
lines changed

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
from typing import Callable, List, Optional, Union
1617

1718
import PIL.Image
@@ -27,7 +28,6 @@
2728
from ...utils.torch_utils import randn_tensor
2829
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
2930
from . import StableDiffusionPipelineOutput
30-
from .pipeline_stable_diffusion_utils import SDMixin
3131
from .safety_checker import StableDiffusionSafetyChecker
3232

3333

@@ -41,7 +41,7 @@
4141
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
4242

4343

44-
class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin, SDMixin):
44+
class StableDiffusionImageVariationPipeline(DiffusionPipeline, StableDiffusionMixin):
4545
r"""
4646
Pipeline to generate image variations from an input image using Stable Diffusion.
4747
@@ -166,6 +166,51 @@ def _encode_image(self, image, device, num_images_per_prompt, do_classifier_free
166166

167167
return image_embeddings
168168

169+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
170+
def run_safety_checker(self, image, device, dtype):
171+
if self.safety_checker is None:
172+
has_nsfw_concept = None
173+
else:
174+
if torch.is_tensor(image):
175+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
176+
else:
177+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
178+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
179+
image, has_nsfw_concept = self.safety_checker(
180+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
181+
)
182+
return image, has_nsfw_concept
183+
184+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
185+
def decode_latents(self, latents):
186+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
187+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
188+
189+
latents = 1 / self.vae.config.scaling_factor * latents
190+
image = self.vae.decode(latents, return_dict=False)[0]
191+
image = (image / 2 + 0.5).clamp(0, 1)
192+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
193+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
194+
return image
195+
196+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
197+
def prepare_extra_step_kwargs(self, generator, eta):
198+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
199+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
200+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
201+
# and should be between [0, 1]
202+
203+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
204+
extra_step_kwargs = {}
205+
if accepts_eta:
206+
extra_step_kwargs["eta"] = eta
207+
208+
# check if the scheduler accepts generator
209+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
210+
if accepts_generator:
211+
extra_step_kwargs["generator"] = generator
212+
return extra_step_kwargs
213+
169214
def check_inputs(self, image, height, width, callback_steps):
170215
if (
171216
not isinstance(image, torch.Tensor)

src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import inspect
1516
from typing import Any, Callable, Dict, List, Optional, Union
1617

1718
import numpy as np
@@ -28,7 +29,6 @@
2829
from ...utils.torch_utils import randn_tensor
2930
from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
3031
from . import StableDiffusionPipelineOutput
31-
from .pipeline_stable_diffusion_utils import SDMixin, retrieve_latents
3232
from .safety_checker import StableDiffusionSafetyChecker
3333

3434

@@ -66,10 +66,23 @@ def preprocess(image):
6666
return image
6767

6868

69+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
70+
def retrieve_latents(
71+
encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
72+
):
73+
if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
74+
return encoder_output.latent_dist.sample(generator)
75+
elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
76+
return encoder_output.latent_dist.mode()
77+
elif hasattr(encoder_output, "latents"):
78+
return encoder_output.latents
79+
else:
80+
raise AttributeError("Could not access latents of provided encoder_output")
81+
82+
6983
class StableDiffusionInstructPix2PixPipeline(
7084
DiffusionPipeline,
7185
StableDiffusionMixin,
72-
SDMixin,
7386
TextualInversionLoaderMixin,
7487
StableDiffusionLoraLoaderMixin,
7588
IPAdapterMixin,
@@ -710,6 +723,51 @@ def prepare_ip_adapter_image_embeds(
710723

711724
return image_embeds
712725

726+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
727+
def run_safety_checker(self, image, device, dtype):
728+
if self.safety_checker is None:
729+
has_nsfw_concept = None
730+
else:
731+
if torch.is_tensor(image):
732+
feature_extractor_input = self.image_processor.postprocess(image, output_type="pil")
733+
else:
734+
feature_extractor_input = self.image_processor.numpy_to_pil(image)
735+
safety_checker_input = self.feature_extractor(feature_extractor_input, return_tensors="pt").to(device)
736+
image, has_nsfw_concept = self.safety_checker(
737+
images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
738+
)
739+
return image, has_nsfw_concept
740+
741+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
742+
def prepare_extra_step_kwargs(self, generator, eta):
743+
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
744+
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
745+
# eta corresponds to η in DDIM paper: https://huggingface.co/papers/2010.02502
746+
# and should be between [0, 1]
747+
748+
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
749+
extra_step_kwargs = {}
750+
if accepts_eta:
751+
extra_step_kwargs["eta"] = eta
752+
753+
# check if the scheduler accepts generator
754+
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
755+
if accepts_generator:
756+
extra_step_kwargs["generator"] = generator
757+
return extra_step_kwargs
758+
759+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
760+
def decode_latents(self, latents):
761+
deprecation_message = "The decode_latents method is deprecated and will be removed in 1.0.0. Please use VaeImageProcessor.postprocess(...) instead"
762+
deprecate("decode_latents", "1.0.0", deprecation_message, standard_warn=False)
763+
764+
latents = 1 / self.vae.config.scaling_factor * latents
765+
image = self.vae.decode(latents, return_dict=False)[0]
766+
image = (image / 2 + 0.5).clamp(0, 1)
767+
# we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
768+
image = image.cpu().permute(0, 2, 3, 1).float().numpy()
769+
return image
770+
713771
def check_inputs(
714772
self,
715773
prompt,
@@ -839,6 +897,21 @@ def prepare_image_latents(
839897

840898
return image_latents
841899

900+
@property
901+
def guidance_scale(self):
902+
return self._guidance_scale
903+
842904
@property
843905
def image_guidance_scale(self):
844906
return self._image_guidance_scale
907+
908+
@property
909+
def num_timesteps(self):
910+
return self._num_timesteps
911+
912+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
913+
# of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
914+
# corresponds to doing no classifier free guidance.
915+
@property
916+
def do_classifier_free_guidance(self):
917+
return self.guidance_scale > 1.0 and self.image_guidance_scale >= 1.0

0 commit comments

Comments
 (0)