diff --git a/examples/gradio/gradio_example.md b/examples/gradio/gradio_example.md
new file mode 100644
index 0000000..c2e42c6
--- /dev/null
+++ b/examples/gradio/gradio_example.md
@@ -0,0 +1,92 @@
+# Gradio Example
+
+- A playground for testing the differential diffusion SD2 pipeline with different masks.
+
+## Features
+- Generate a gradient mask at runtime.
+- Import a mask image.
+- Extract a depth image from a given image with the [Intel/dpt-large](https://huggingface.co/Intel/dpt-large) model.
+- Brightness, contrast and image transform functionality for the mask images.
+- Generate output from an input image using positive and negative prompts, guidance scale, strength and number of inference steps parameters.
+
+## Required Packages
+- transformers
+- accelerate
+- torch
+- diffusers
+- gradio
+
+## Usage
+1. Go to the ```differential-diffusion``` directory
+2. Activate the virtual environment if there is one
+3. Run the ```python examples/gradio/main.py``` command
+
+## Tabs
+
+- When switching between tabs, the ```Mask Image``` is updated with the selected tab's configured mask.
+
+### Gradient Mask
+
+![gradient_mask_tab_ss](screenshots/gradient_mask_tab.png)
+
+| Variable              | Definition                                           |
+| :-------------------: | :--------------------------------------------------- |
+| Gradient Mask         | Temporary gradient mask used in the calculations     |
+| Gradient Image Width  | Define the width of the gradient image               |
+| Gradient Image Height | Define the height of the gradient image              |
+| Gradient Strength     | Set the white and black ratios                       |
+| Image Brightness      | Set the gradient mask image's brightness             |
+| Image Contrast        | Set the gradient mask image's contrast               |
+| Flip Horizontal       | Flip the gradient mask image horizontally            |
+| To Vertical           | Rotate the gradient mask image 90 degrees clockwise  |
+| Flip Vertical         | Flip the gradient mask image vertically              |
+
+- ```Mask Image``` is updated when the value of any variable, except ```Gradient Mask```, is changed
+
+### Image Mask
+
+![image_mask_tab_ss](screenshots/image_mask_tab.png)
+
+| Variable         | Definition                                                          |
+| :--------------: | :------------------------------------------------------------------ |
+| Image Mask       | The mask image; it can be uploaded as a file or from the clipboard  |
+| Image Brightness | Set the input mask image's brightness                               |
+| Image Contrast   | Set the input mask image's contrast                                 |
+| Flip Horizontal  | Flip the input mask image horizontally                              |
+| To Vertical      | Rotate the input mask image 90 degrees clockwise                    |
+| Flip Vertical    | Flip the input mask image vertically                                |
+
+- ```Mask Image``` is updated when the value of any variable is changed
+
+### Extract Depth
+
+![extract_depth_tab_ss](screenshots/extract_depth_tab.png)
+
+| Variable              | Definition                                               |
+| :-------------------: | :------------------------------------------------------- |
+| Input Image           | Original image from which the depth image is extracted   |
+| Extracted Depth Image | Shows the extracted depth image                           |
+| Image Brightness      | Set the extracted depth image's brightness                |
+| Image Contrast        | Set the extracted depth image's contrast                  |
+| Flip Horizontal       | Flip the extracted depth image horizontally               |
+| To Vertical           | Rotate the extracted depth image 90 degrees clockwise     |
+| Flip Vertical         | Flip the extracted depth image vertically                 |
+
+- ```Mask Image``` is updated when the value of any variable, except ```Extracted Depth Image```, is changed
+
+### Generate
+
+![generate_tab_ss](screenshots/generated_tab.png)
+
+| Variable        | Definition                                                    |
+| :-------------: | :------------------------------------------------------------ |
+| Image           | Input image for the differential diffusion Img2Img pipeline   |
+| Guidance Scale  | Guidance scale for the pipeline                                |
+| Inference Step  | Number of inference steps for the pipeline                     |
+| Strength        | Strength variable for the pipeline                             |
+| Positive Prompt | Positive prompt for the pipeline                               |
+| Negative Prompt | Negative prompt for the pipeline                               |
+
+- The generation process is started with the ```Generate``` button.
+- The generated output is shown in ```Output Image```.
+- ```Mask Image``` must contain an image, otherwise generation will fail.
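+
+- For reference, the sketch below shows roughly how the ```Generate``` tab drives the pipeline outside of Gradio. It assumes the script is run from the ```examples/gradio``` directory (so the local ```utilities``` and ```tabs``` packages are importable); the prompts and the output path are illustrative.
+
+```py
+from PIL import Image
+
+from tabs.generate_tab import preprocess_image, preprocess_map
+from utilities.SD2.diff_pipe import StableDiffusionDiffImg2ImgPipeline
+
+# Load the pipeline the same way the Generate tab does.
+pipe = StableDiffusionDiffImg2ImgPipeline.from_pretrained(
+    "stabilityai/stable-diffusion-2-1-base",
+    device_map="auto",
+)
+
+# Preprocess the input image and the change map (mask) shipped with the example.
+image = preprocess_image(Image.open("assets/input.jpg"))
+change_map = preprocess_map(Image.open("assets/map2.jpg"))
+
+result = pipe(
+    prompt=["your positive prompt"],  # illustrative prompt
+    negative_prompt=["your negative prompt"],  # illustrative prompt
+    image=image,
+    map=change_map,
+    num_inference_steps=100,
+    guidance_scale=7,
+    strength=1.0,
+    num_images_per_prompt=1,
+).images[0]
+
+result.save("output.png")  # illustrative output path
+```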
diff --git a/examples/gradio/main.py b/examples/gradio/main.py
new file mode 100644
index 0000000..c7b6a4a
--- /dev/null
+++ b/examples/gradio/main.py
@@ -0,0 +1,116 @@
+from gradio import Blocks, Tab, Row, Column, Image
+
+from utilities.event_funcs import gradient_calculate, image_edit_change
+
+from tabs.gradient_tab import GradientTab
+from tabs.image_mask_tab import ImageMaskTab
+from tabs.generate_tab import GenerateTab
+from tabs.extract_depth_tab import ExtractDepthTab
+
+from importlib.util import find_spec
+
+
+def check_package(package_name: str) -> None:
+    if find_spec(package_name):
+        print(f"/_\\ {package_name} is found")
+    else:
+        print(
+            f"/_\\ {package_name} is not found. Please install the {package_name} package with pip"
+        )
+        exit(1)
+
+
+print(" Differential Diffusion Gradio Example ".center(100, "-"))
+print("/_\\ Checking Packages")
+
+check_package("diffusers")
+check_package("transformers")
+check_package("accelerate")
+check_package("torch")
+check_package("gradio")
+
+print("/_\\ Launching example")
+
+gradient_tab = GradientTab()
+image_mask_tab = ImageMaskTab()
+extract_depth_tab = ExtractDepthTab()
+generate_tab = GenerateTab()
+
+with Blocks() as example:
+    with Row() as main_row:
+        with Tab("Gradient Mask") as tab_gradient:
+            gradient_tab.render()
+
+        with Tab("Image Mask") as tab_image_mask:
+            image_mask_tab.render()
+
+        with Tab("Extract Depth") as tab_extracted_depth:
+            extract_depth_tab.render()
+
+        with Tab("Generate") as tab_generate:
+            generate_tab.render()
+
+        with Column():
+            mask_image = Image(
+                value=gradient_calculate(512, 512, 1.0, 1.0, 1.0, False, False, False),
+                sources=None,
+                label="Mask Image",
+                width=512,
+                height=512,
+            )
+            output_image = Image(
+                sources=None, label="Output Image", width=512, height=512
+            )
+
+    tab_gradient.select(
+        gradient_calculate,
+        inputs=[
+            gradient_tab.width_slider,
+            gradient_tab.height_slider,
+            gradient_tab.strength_slider,
+            gradient_tab.image_edit.brightness_slider,
+            gradient_tab.image_edit.contrast_slider,
+            gradient_tab.image_edit.flip_horizontal_checkbox,
+            gradient_tab.image_edit.to_vertical_checkbox,
+            gradient_tab.image_edit.flip_vertical_checkbox,
+        ],
+        outputs=mask_image,
+        show_progress="hidden",
+    )
+
+    tab_image_mask.select(
+        image_edit_change,
+        inputs=[
+            image_mask_tab.mask_image,
+            image_mask_tab.image_edit.brightness_slider,
+            image_mask_tab.image_edit.contrast_slider,
+            image_mask_tab.image_edit.flip_horizontal_checkbox,
+            image_mask_tab.image_edit.to_vertical_checkbox,
+            image_mask_tab.image_edit.flip_vertical_checkbox,
+        ],
+        outputs=mask_image,
+        show_progress="hidden",
+    )
+
+    tab_extracted_depth.select(
+        image_edit_change,
+        inputs=[
+            extract_depth_tab.extracted_depth_image,
+            extract_depth_tab.image_edit.brightness_slider,
+            extract_depth_tab.image_edit.contrast_slider,
+            extract_depth_tab.image_edit.flip_horizontal_checkbox,
+            extract_depth_tab.image_edit.to_vertical_checkbox,
+            extract_depth_tab.image_edit.flip_vertical_checkbox,
+        ],
+        outputs=mask_image,
+        show_progress="hidden",
+    )
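+
+    # The attach_event calls below register the widget-level events of each tab
+    # (sliders, checkboxes and uploads) so that edits refresh the shared Mask
+    # Image; the Generate tab additionally writes its result to the Output Image.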
+ + gradient_tab.attach_event(mask_image) + image_mask_tab.attach_event(mask_image) + extract_depth_tab.attach_event(mask_image) + generate_tab.attach_event(mask_image, output_image) + + +if __name__ == "__main__": + example.launch() diff --git a/examples/gradio/screenshots/extract_depth_tab.png b/examples/gradio/screenshots/extract_depth_tab.png new file mode 100644 index 0000000..a0483ae Binary files /dev/null and b/examples/gradio/screenshots/extract_depth_tab.png differ diff --git a/examples/gradio/screenshots/generated_tab.png b/examples/gradio/screenshots/generated_tab.png new file mode 100644 index 0000000..231a347 Binary files /dev/null and b/examples/gradio/screenshots/generated_tab.png differ diff --git a/examples/gradio/screenshots/gradient_mask_tab.png b/examples/gradio/screenshots/gradient_mask_tab.png new file mode 100644 index 0000000..90f2bb5 Binary files /dev/null and b/examples/gradio/screenshots/gradient_mask_tab.png differ diff --git a/examples/gradio/screenshots/image_mask_tab.png b/examples/gradio/screenshots/image_mask_tab.png new file mode 100644 index 0000000..89a1777 Binary files /dev/null and b/examples/gradio/screenshots/image_mask_tab.png differ diff --git a/examples/gradio/tabs/extract_depth_tab.py b/examples/gradio/tabs/extract_depth_tab.py new file mode 100644 index 0000000..1110651 --- /dev/null +++ b/examples/gradio/tabs/extract_depth_tab.py @@ -0,0 +1,77 @@ +from transformers import pipeline +from gradio import Column, Image, Row +import numpy as np +from PIL import Image as pil_image +from torch.nn.functional import interpolate + +from .image_edit_block import ImageEditBlock + + +class ExtractDepthTab: + def __init__(self) -> None: + self.main_column = Column() + + self.images_row = Row() + + self.wanted_image = Image( + sources=["upload", "clipboard"], + type="pil", + width=512, + height=512, + label="Input Image", + ) + + from os.path import join + + self.extracted_depth_image = Image( + value=pil_image.open(join("assets", "map2.jpg")), + width=512, + height=512, + type="pil", + label="Extracted Depth Image", + sources=None, + ) + + self.pipe = pipeline( + "depth-estimation", + model="Intel/dpt-large", + framework="pt", + torch_dtype="auto", + ) + + self.image_edit = ImageEditBlock() + + def render(self) -> None: + with self.images_row: + self.wanted_image.render() + self.extracted_depth_image.render() + + self.images_row.render() + self.image_edit.render() + + def attach_event(self, output_image) -> None: + def extract_depth(given_image): + outputs = self.pipe(given_image) + predicted_depth = outputs["predicted_depth"] + + prediction = interpolate( + input=predicted_depth.unsqueeze(1), + size=given_image.size[::-1], + mode="bicubic", + align_corners=False, + ) + + output = prediction.squeeze().cpu().numpy() + formatted = (output * 255 / np.max(output)).astype("uint8") + pil_formatted_image = pil_image.fromarray(formatted) + + return [pil_formatted_image, pil_formatted_image] + + self.wanted_image.upload( + extract_depth, + inputs=self.wanted_image, + outputs=[self.extracted_depth_image, output_image], + show_progress="full", + ) + + self.image_edit.attach_event(self.extracted_depth_image, output_image) diff --git a/examples/gradio/tabs/generate_tab.py b/examples/gradio/tabs/generate_tab.py new file mode 100644 index 0000000..c906080 --- /dev/null +++ b/examples/gradio/tabs/generate_tab.py @@ -0,0 +1,140 @@ +from gradio import Column, Row, Textbox, Image, Button, Slider +from utilities.SD2.diff_pipe import StableDiffusionDiffImg2ImgPipeline +from 
torchvision import transforms +from torch.cuda import is_available +from PIL import Image as pil_image + +device = "cuda" if is_available() else "cpu" + + +def preprocess_image(image): + image = image.convert("RGB") + image = transforms.CenterCrop((image.size[1] // 64 * 64, image.size[0] // 64 * 64))( + image + ) + image = transforms.ToTensor()(image) + image = image * 2 - 1 + image = image.unsqueeze(0).to(device) + return image + + +def preprocess_map(map): + map = map.convert("L") + map = transforms.CenterCrop((map.size[1] // 64 * 64, map.size[0] // 64 * 64))(map) + # convert to tensor + map = transforms.ToTensor()(map) + map = map.to(device) + return map + + +class GenerateTab: + def __init__(self) -> None: + self.config_row = Row() + self.parameter_column = Column() + + from os.path import join + + self.input_image = Image( + value=pil_image.open(join("assets", "input.jpg")), + width=512, + height=512, + show_download_button=False, + sources=["upload", "clipboard"], + interactive=True, + ) + + self.guidance_scale_slider = Slider( + minimum=0, + maximum=20, + value=7, + label="Guidance Scale", + step=0.5, + interactive=True, + ) + + self.number_of_steps_slider = Slider( + minimum=0, + maximum=200, + step=1, + value=100, + label="Inference Step", + interactive=True, + ) + + self.strength_slider = Slider( + minimum=0, maximum=1, value=1, step=0.1, label="Strength", interactive=True + ) + + self.positive_prompt_textbox = Textbox( + value="painting of a mountain landscape with a meadow and a forest, meadow background, anime countryside landscape, anime nature wallpap, anime landscape wallpaper, studio ghibli landscape, anime landscape, mountain behind meadow, anime background art, studio ghibli environment, background of flowery hill, anime beautiful peace scene, forrest background, anime scenery, landscape background, background art, anime scenery concept art", + max_lines=300, + label="Positive Prompt", + ) + + self.negative_prompt_textbox = Textbox( + value="blurry, shadow polaroid photo, scary angry pose, worn decay texture, portrait fashion model, piercing stare, bruised face, demoness", + max_lines=300, + label="Negative Prompt", + ) + + self.generate_button = Button(value="Generate") + + self.pipe = StableDiffusionDiffImg2ImgPipeline.from_pretrained( + "stabilityai/stable-diffusion-2-1-base", + device_map="auto", + ) + + def render(self) -> None: + with self.parameter_column: + self.guidance_scale_slider.render() + self.number_of_steps_slider.render() + self.strength_slider.render() + + with self.config_row: + self.input_image.render() + self.parameter_column.render() + + self.config_row.render() + self.positive_prompt_textbox.render() + self.negative_prompt_textbox.render() + self.generate_button.render() + + def attach_event(self, mask_image, result_image) -> None: + + def generate_image( + map_image, + input_image, + guidance_scale, + inference_steps, + strength, + positive_prompt, + negative_prompt, + ): + processed_mask_image = preprocess_map(pil_image.fromarray(map_image)) + processed_input_image = preprocess_image(pil_image.fromarray(input_image)) + + return self.pipe( + prompt=[positive_prompt], + image=processed_input_image, + num_images_per_prompt=1, + negative_prompt=[negative_prompt], + map=processed_mask_image, + num_inference_steps=inference_steps, + guidance_scale=guidance_scale, + strength=strength, + ).images[0] + + self.generate_button.click( + generate_image, + inputs=[ + mask_image, + self.input_image, + self.guidance_scale_slider, + self.number_of_steps_slider, + 
self.strength_slider, + self.positive_prompt_textbox, + self.negative_prompt_textbox, + ], + outputs=result_image, + show_progress="full", + ) diff --git a/examples/gradio/tabs/gradient_tab.py b/examples/gradio/tabs/gradient_tab.py new file mode 100644 index 0000000..b625d3c --- /dev/null +++ b/examples/gradio/tabs/gradient_tab.py @@ -0,0 +1,105 @@ +from gradio import Slider, Image, on +from PIL.Image import Image as pil_image + +from utilities.gradient import create_gradient +from utilities.event_funcs import image_enhancement_change, image_transform_change + +from .image_edit_block import ImageEditBlock + + +class GradientTab: + def __init__(self) -> None: + from utilities.event_funcs import gradient_calculate + + self.default_image = Image( + value=gradient_calculate(512, 512, 1.0, 1.0, 1.0, False, False, False), + image_mode="L", + sources=None, + type="pil", + label="Gradient Mask", + interactive=False, + width=256, + height=256, + ) + + self.width_slider = Slider( + minimum=512, + maximum=4096, + value=512, + step=1, + label="Gradient Image Width", + interactive=True, + ) + + self.height_slider = Slider( + minimum=512, + maximum=4096, + value=512, + step=1, + label="Gradient Image Height", + interactive=True, + ) + + self.strength_slider = Slider( + minimum=0, + maximum=20, + value=1, + step=0.5, + label="Gradient Strength", + interactive=True, + ) + + self.image_edit = ImageEditBlock() + + def render(self) -> None: + self.default_image.render() + self.width_slider.render() + self.height_slider.render() + self.strength_slider.render() + self.image_edit.render() + + def attach_event( + self, + output_image, + ) -> None: + @on( + triggers=[ + self.width_slider.release, + self.height_slider.release, + self.strength_slider.release, + self.image_edit.brightness_slider.release, + self.image_edit.contrast_slider.release, + self.image_edit.flip_horizontal_checkbox.select, + self.image_edit.to_vertical_checkbox.select, + self.image_edit.flip_vertical_checkbox.select, + ], + inputs=[ + self.width_slider, + self.height_slider, + self.strength_slider, + self.image_edit.brightness_slider, + self.image_edit.contrast_slider, + self.image_edit.flip_horizontal_checkbox, + self.image_edit.to_vertical_checkbox, + self.image_edit.flip_vertical_checkbox, + ], + outputs=output_image, + show_progress="hidden", + ) + def gradient_calculate( + image_width: int, + image_height: int, + strength: float, + brightness: float, + contrast: float, + is_flip_horizontal: bool, + is_to_vertical: bool, + is_flip_vertical: bool, + ) -> pil_image: + image = create_gradient(image_width, image_height, strength) + image = image_enhancement_change(image, brightness, contrast) + image = image_transform_change( + image, is_flip_horizontal, is_to_vertical, is_flip_vertical + ) + + return image diff --git a/examples/gradio/tabs/image_edit_block.py b/examples/gradio/tabs/image_edit_block.py new file mode 100644 index 0000000..3335801 --- /dev/null +++ b/examples/gradio/tabs/image_edit_block.py @@ -0,0 +1,97 @@ +from gradio import Column, Row, Slider, Checkbox +from PIL.Image import Image as pil_image +from gradio import on +from PIL.ImageEnhance import Brightness, Contrast +from PIL.Image import Transpose + + +class ImageEditBlock: + def __init__(self) -> None: + self.enhancement_column = Column() + + self.brightness_slider = Slider( + minimum=0.0, + maximum=10.0, + value=1.0, + step=0.5, + label="Image Brightness", + interactive=True, + ) + + self.contrast_slider = Slider( + minimum=0.0, + maximum=10.0, + value=1.0, + step=0.5, + 
label="Image Contrast", + interactive=True, + ) + + self.image_transformation_row = Row() + + self.flip_horizontal_checkbox = Checkbox( + value=False, label="Flip Horizontal", interactive=True + ) + + self.to_vertical_checkbox = Checkbox( + value=False, label="To Vertical", interactive=True + ) + + self.flip_vertical_checkbox = Checkbox( + value=False, label="Flip Vertical", interactive=True + ) + + def render(self) -> None: + with self.enhancement_column: + self.brightness_slider.render() + self.contrast_slider.render() + + with self.image_transformation_row: + self.flip_horizontal_checkbox.render() + self.to_vertical_checkbox.render() + self.flip_vertical_checkbox.render() + + self.enhancement_column.render() + self.image_transformation_row.render() + + def attach_event(self, input_image, output_image) -> None: + @on( + triggers=[ + self.brightness_slider.release, + self.contrast_slider.release, + self.flip_horizontal_checkbox.select, + self.to_vertical_checkbox.select, + self.flip_vertical_checkbox.select, + ], + inputs=[ + input_image, + self.brightness_slider, + self.contrast_slider, + self.flip_horizontal_checkbox, + self.to_vertical_checkbox, + self.flip_vertical_checkbox, + ], + outputs=output_image, + show_progress="hidden", + ) + def image_edit_change( + image: pil_image, + brightness: float, + contrast: float, + is_flip_horizontal: bool, + is_to_vertical: bool, + is_flip_vertical: bool, + ) -> pil_image: + image = Brightness(image).enhance(brightness) + image = Contrast(image).enhance(contrast) + + if is_flip_horizontal: + image = image.transpose(Transpose.FLIP_LEFT_RIGHT) + + if is_to_vertical: + image = image.transpose(Transpose.ROTATE_90) + + if is_flip_vertical: + image = image.transpose(Transpose.FLIP_TOP_BOTTOM) + + return image diff --git a/examples/gradio/tabs/image_mask_tab.py b/examples/gradio/tabs/image_mask_tab.py new file mode 100644 index 0000000..fd1d0b0 --- /dev/null +++ b/examples/gradio/tabs/image_mask_tab.py @@ -0,0 +1,38 @@ +from gradio import Column, Image +from PIL import Image as pil_image +from .image_edit_block import ImageEditBlock + + +class ImageMaskTab: + def __init__(self) -> None: + self.main_column = Column() + + from os.path import join + + self.mask_image = Image( + value=pil_image.open(join("assets", "map2.jpg")), + image_mode="L", + sources=["upload", "clipboard"], + type="pil", + label="Image Mask", + interactive=True, + ) + + self.image_edit = ImageEditBlock() + + def render(self) -> None: + self.mask_image.render() + self.image_edit.render() + + def attach_event(self, output_image) -> None: + def upload_image(image): + return image + + self.mask_image.upload( + upload_image, + inputs=self.mask_image, + outputs=output_image, + show_progress="hidden", + ) + + self.image_edit.attach_event(self.mask_image, output_image) diff --git a/examples/gradio/utilities/SD2/diff_pipe.py b/examples/gradio/utilities/SD2/diff_pipe.py new file mode 100644 index 0000000..3a85343 --- /dev/null +++ b/examples/gradio/utilities/SD2/diff_pipe.py @@ -0,0 +1,751 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Callable, List, Optional, Union + +import numpy as np +import PIL +import torch +import torchvision.transforms +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + PIL_INTERPOLATION, + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + randn_tensor, + replace_example_docstring, +) +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import requests + >>> import torch + >>> from PIL import Image + >>> from io import BytesIO + + >>> from diffusers import StableDiffusionImg2ImgPipeline + + >>> device = "cuda" + >>> model_id_or_path = "runwayml/stable-diffusion-v1-5" + >>> pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_id_or_path, torch_dtype=torch.float16) + >>> pipe = pipe.to(device) + + >>> url = "https://raw.githubusercontent.com/CompVis/stable-diffusion/main/assets/stable-samples/img2img/sketch-mountains-input.jpg" + + >>> response = requests.get(url) + >>> init_image = Image.open(BytesIO(response.content)).convert("RGB") + >>> init_image = init_image.resize((768, 512)) + + >>> prompt = "A fantasy landscape, trending on artstation" + + >>> images = pipe(prompt=prompt, image=init_image, strength=0.75, guidance_scale=7.5).images + >>> images[0].save("fantasy_landscape.png") + ``` +""" + + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = map(lambda x: x - x % 8, (w, h)) # resize to integer multiple of 8 + + image = [np.array(i.resize((w, h), resample=PIL_INTERPOLATION["lanczos"]))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionDiffImg2ImgPipeline(DiffusionPipeline): + r""" + Pipeline for text-guided image to image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. 
+ tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.__init__ + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + scheduler: KarrasDiffusionSchedulers, + safety_checker: StableDiffusionSafetyChecker, + feature_extractor: CLIPFeatureExtractor, + requires_safety_checker: bool = True, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + if safety_checker is None and requires_safety_checker: + logger.warning( + f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure" + " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered" + " results in services or applications open to the public. 
Both the diffusers team and Hugging Face" + " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling" + " it only for use-cases that involve analyzing network behavior or auditing its results. For more" + " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ." + ) + + if safety_checker is not None and feature_extractor is None: + raise ValueError( + "Make sure to define a feature extractor when loading {self.__class__} if you want to use the safety" + " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead." + ) + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + safety_checker=safety_checker, + feature_extractor=feature_extractor, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + self.register_to_config(requires_safety_checker=requires_safety_checker) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. 
+ """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + if self.safety_checker is not None: + cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_model_cpu_offload + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + if self.safety_checker is not None: + _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. 
+ + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. 
Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker + def run_safety_checker(self, image, device, dtype): + if self.safety_checker is not None: + safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device) + image, has_nsfw_concept = self.safety_checker( + images=image, clip_input=safety_checker_input.pixel_values.to(dtype) + ) + else: + has_nsfw_concept = None + return image, has_nsfw_concept + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. 
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, prompt, strength, callback_steps, negative_prompt=None, prompt_embeds=None, negative_prompt_embeds=None + ): + if strength < 0 or strength > 1: + raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start:] + + return timesteps, num_inference_steps - t_start + + def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}" + ) + + image = image.to(device=device, dtype=dtype) + + batch_size = batch_size * num_images_per_prompt + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." 
+ ) + + if isinstance(generator, list): + init_latents = [ + self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size) + ] + init_latents = torch.cat(init_latents, dim=0) + else: + init_latents = self.vae.encode(image).latent_dist.sample(generator) + + init_latents = self.vae.config.scaling_factor * init_latents + + if batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] == 0: + # expand init_latents for batch_size + deprecation_message = ( + f"You have passed {batch_size} text prompts (`prompt`), but only {init_latents.shape[0]} initial" + " images (`image`). Initial images are now duplicating to match the number of text prompts. Note" + " that this behavior is deprecated and will be removed in a version 1.0.0. Please make sure to update" + " your script to pass as many initial images as text prompts to suppress this warning." + ) + deprecate("len(prompt) != len(image)", "1.0.0", deprecation_message, standard_warn=False) + additional_image_per_prompt = batch_size // init_latents.shape[0] + init_latents = torch.cat([init_latents] * additional_image_per_prompt, dim=0) + elif batch_size > init_latents.shape[0] and batch_size % init_latents.shape[0] != 0: + raise ValueError( + f"Cannot duplicate `image` of batch size {init_latents.shape[0]} to {batch_size} text prompts." + ) + else: + init_latents = torch.cat([init_latents], dim=0) + + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + latents = init_latents + + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image] = None, + strength: float = 1, + num_inference_steps: Optional[int] = 50, + guidance_scale: Optional[float] = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: Optional[float] = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + map:torch.FloatTensor = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`torch.FloatTensor` or `PIL.Image.Image`): + `Image`, or tensor representing an image batch, that will be used as the starting point for the + process. + strength (`float`, *optional*, defaults to 1): + Repealed in favor of the map. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. This parameter will be modulated by `strength`. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. 
Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 1. Check inputs. Raise error if not correct + self.check_inputs(prompt, strength, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. 
Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess image + image = preprocess(image) + + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + + + # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + map = torchvision.transforms.Resize(tuple(s // self.vae_scale_factor for s in image.shape[2:]),antialias=None)(map) + + # 8. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + + # prepartions + original_with_noise = self.prepare_latents( + image, timesteps, batch_size, num_images_per_prompt, prompt_embeds.dtype, device, generator + ) + thresholds = torch.arange(len(timesteps), dtype=map.dtype) / len(timesteps) + thresholds = thresholds.unsqueeze(1).unsqueeze(1).to(device) + masks = map > thresholds + # end diff diff preparations + + with self.progress_bar(total=num_inference_steps) as progress_bar: + + for i, t in enumerate(timesteps): + # diff diff + if i == 0: + latents = original_with_noise[:1] + else: + mask = masks[i].unsqueeze(0) + # cast mask to the same type as latents etc + mask = mask.to(latents.dtype) + mask = mask.unsqueeze(1) # fit shape + latents = original_with_noise[i] * mask + latents * (1 - mask) + # end diff diff + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + # 9. Post-processing + image = self.decode_latents(latents) + + # 10. Run safety checker + image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype) + #has_nsfw_concept = False + + # 11. 
Convert to PIL
+        if output_type == "pil":
+            image = self.numpy_to_pil(image)
+
+        # Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
+        if not return_dict:
+            return (image, has_nsfw_concept)
+
+        return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
diff --git a/examples/gradio/utilities/event_funcs.py b/examples/gradio/utilities/event_funcs.py
new file mode 100644
index 0000000..ae4c390
--- /dev/null
+++ b/examples/gradio/utilities/event_funcs.py
@@ -0,0 +1,71 @@
+from PIL.Image import Image as pil_image
+from PIL.Image import Transpose
+from PIL.ImageEnhance import Brightness, Contrast
+
+from .gradient import create_gradient
+
+
+def image_enhancement_change(
+    image: pil_image, brightness: float, contrast: float
+) -> pil_image:
+    image = Brightness(image).enhance(brightness)
+    image = Contrast(image).enhance(contrast)
+
+    return image
+
+
+def image_transform_change(
+    image: pil_image,
+    is_flip_horizontal: bool,
+    is_to_vertical: bool,
+    is_flip_vertical: bool,
+) -> pil_image:
+    if is_to_vertical:
+        image = image.transpose(Transpose.ROTATE_90)
+
+    if is_flip_vertical:
+        image = image.transpose(Transpose.FLIP_TOP_BOTTOM)
+
+    if is_flip_horizontal:
+        image = image.transpose(Transpose.FLIP_LEFT_RIGHT)
+
+    return image
+
+
+def image_edit_change(
+    image: pil_image,
+    brightness: float,
+    contrast: float,
+    is_flip_horizontal: bool,
+    is_to_vertical: bool,
+    is_flip_vertical: bool,
+) -> pil_image:
+    image = image_enhancement_change(image, brightness, contrast)
+    image = image_transform_change(
+        image, is_flip_horizontal, is_to_vertical, is_flip_vertical
+    )
+
+    return image
+
+
+def gradient_calculate(
+    image_width: int,
+    image_height: int,
+    strength: float,
+    brightness: float,
+    contrast: float,
+    is_flip_horizontal: bool,
+    is_to_vertical: bool,
+    is_flip_vertical: bool,
+) -> pil_image:
+    image = create_gradient(image_width, image_height, strength)
+    image = image_edit_change(
+        image,
+        brightness,
+        contrast,
+        is_flip_horizontal,
+        is_to_vertical,
+        is_flip_vertical,
+    )
+
+    return image
diff --git a/examples/gradio/utilities/gradient.py b/examples/gradio/utilities/gradient.py
new file mode 100644
index 0000000..17a8101
--- /dev/null
+++ b/examples/gradio/utilities/gradient.py
@@ -0,0 +1,39 @@
+from PIL import Image
+from PIL.Image import Image as pil_image
+import numpy as np
+
+# linear interpolation
+def l_interp(start: list[int], end: list[int], alpha: float) -> list[int]:
+    return [
+        min(255, max(0, int(start[index] + alpha * (end[index] - start[index]))))
+        for index in range(3)
+    ]
+
+
+def create_gradient(
+    image_width: int,
+    image_height: int,
+    strength: float,
+) -> pil_image:
+    start_pixel = (255, 255, 255)  # white
+    end_pixel = (0, 0, 0)  # black
+
+    row_pixels = []
+
+    for width in range(image_width):
+        current_pixel = l_interp(
+            start_pixel, end_pixel, ((float(width) / float(image_width)) * strength)
+        )
+
+        row_pixels.append(current_pixel)
+
+    image_pixels = []
+    image_pixels.extend(row_pixels * image_height)
+    image_array = np.array(image_pixels, dtype=np.uint8)
+    image_array = np.reshape(image_array, [image_height, image_width, 3])
+
+    image = Image.fromarray(image_array)
+
+    image = image.convert("L")
+
+    return image
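+
+
+if __name__ == "__main__":
+    # Optional quick check (not used by the Gradio app): render a sample
+    # 512x512 gradient mask and save it for visual inspection. The output
+    # filename is an arbitrary choice for this sketch.
+    preview = create_gradient(512, 512, 1.0)
+    preview.save("gradient_preview.png")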