diff --git a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
index 96d53663b867..e7345ad96624 100644
--- a/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import inspect
-from typing import Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import PIL.Image
 import torch
@@ -25,7 +25,7 @@
 )
 
 from ...image_processor import PipelineImageInput, VaeImageProcessor
-from ...loaders import SD3LoraLoaderMixin
+from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import SD3Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -149,7 +149,7 @@ def retrieve_timesteps(
     return timesteps, num_inference_steps
 
 
-class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
+class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Args:
         transformer ([`SD3Transformer2DModel`]):
@@ -680,6 +680,10 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
     def guidance_scale(self):
         return self._guidance_scale
 
+    @property
+    def joint_attention_kwargs(self):
+        return self._joint_attention_kwargs
+
     @property
     def clip_skip(self):
         return self._clip_skip
@@ -723,6 +727,7 @@ def __call__(
         negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
         clip_skip: Optional[int] = None,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -797,6 +802,10 @@ def __call__(
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
                 of a plain tuple.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -835,6 +844,7 @@ def __call__(
 
         self._guidance_scale = guidance_scale
         self._clip_skip = clip_skip
+        self._joint_attention_kwargs = joint_attention_kwargs
         self._interrupt = False
 
         # 2. Define call parameters
@@ -847,6 +857,10 @@ def __call__(
 
         device = self._execution_device
 
+        lora_scale = (
+            self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
+        )
+
         (
             prompt_embeds,
             negative_prompt_embeds,
@@ -868,6 +882,7 @@ def __call__(
             clip_skip=self.clip_skip,
             num_images_per_prompt=num_images_per_prompt,
             max_sequence_length=max_sequence_length,
+            lora_scale=lora_scale,
         )
 
         if self.do_classifier_free_guidance:
@@ -912,6 +927,7 @@ def __call__(
                     timestep=timestep,
                     encoder_hidden_states=prompt_embeds,
                     pooled_projections=pooled_prompt_embeds,
+                    joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
                 )[0]
 
diff --git a/tests/lora/test_lora_layers_sd3.py b/tests/lora/test_lora_layers_sd3.py
index 8f61c95c2fc8..78d4b786d21b 100644
--- a/tests/lora/test_lora_layers_sd3.py
+++ b/tests/lora/test_lora_layers_sd3.py
@@ -12,13 +12,29 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import gc
 import sys
 import unittest
 
+import numpy as np
+import torch
 from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel
 
-from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
-from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
+from diffusers import (
+    FlowMatchEulerDiscreteScheduler,
+    SD3Transformer2DModel,
+    StableDiffusion3Img2ImgPipeline,
+    StableDiffusion3Pipeline,
+)
+from diffusers.utils import load_image
+from diffusers.utils.import_utils import is_accelerate_available
+from diffusers.utils.testing_utils import (
+    is_peft_available,
+    numpy_cosine_similarity_distance,
+    require_peft_backend,
+    require_torch_gpu,
+    torch_device,
+)
 
 
 if is_peft_available():
@@ -29,6 +45,10 @@
 from utils import PeftLoraLoaderMixinTests  # noqa: E402
 
 
+if is_accelerate_available():
+    from accelerate.utils import release_memory
+
+
 @require_peft_backend
 class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
     pipeline_class = StableDiffusion3Pipeline
@@ -108,3 +128,88 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(se
     @unittest.skip("Not supported in SD3.")
     def test_modify_padding_mode(self):
         pass
+
+
+@require_torch_gpu
+@require_peft_backend
+class LoraSD3IntegrationTests(unittest.TestCase):
+    pipeline_class = StableDiffusion3Img2ImgPipeline
+    repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"
+
+    def setUp(self):
+        super().setUp()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def tearDown(self):
+        super().tearDown()
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def get_inputs(self, device, seed=0):
+        init_image = load_image(
+            "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
+        )
+        if str(device).startswith("mps"):
+            generator = torch.manual_seed(seed)
+        else:
+            generator = torch.Generator(device="cpu").manual_seed(seed)
+
+        return {
+            "prompt": "corgi",
+            "num_inference_steps": 2,
+            "guidance_scale": 5.0,
+            "output_type": "np",
+            "generator": generator,
+            "image": init_image,
+        }
+
+    def test_sd3_img2img_lora(self):
+        pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
+        pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
+        pipe.enable_sequential_cpu_offload()
+
+        inputs = self.get_inputs(torch_device)
+
+        image = pipe(**inputs).images[0]
+        image_slice = image[0, :10, :10]
+        expected_slice = np.array(
+            [
+                0.47827148,
+                0.5,
+                0.71972656,
+                0.3955078,
+                0.4194336,
+                0.69628906,
+                0.37036133,
+                0.40820312,
+                0.6923828,
+                0.36450195,
+                0.40429688,
+                0.6904297,
+                0.35595703,
+                0.39257812,
+                0.68652344,
+                0.35498047,
+                0.3984375,
+                0.68310547,
+                0.34716797,
+                0.3996582,
+                0.6855469,
+                0.3388672,
+                0.3959961,
+                0.6816406,
+                0.34033203,
+                0.40429688,
+                0.6845703,
+                0.34228516,
+                0.4086914,
+                0.6870117,
+            ]
+        )
+
+        max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())
+
+        assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
+        pipe.unload_lora_weights()
+        release_memory(pipe)
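
Taken together, these changes let `StableDiffusion3Img2ImgPipeline` load LoRA weights (via the new `SD3LoraLoaderMixin` base) and scale them at call time through `joint_attention_kwargs`. Below is a minimal usage sketch of what the diff enables; the LoRA repo id, weight filename, scale value, and sampling settings are illustrative placeholders, not part of this PR:

```python
import torch

from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")

# `load_lora_weights` is inherited from the newly added SD3LoraLoaderMixin.
# The repo id and weight name below are placeholders for any SD3-compatible LoRA.
pipe.load_lora_weights("your-username/your-sd3-lora", weight_name="lora.safetensors")

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

# `joint_attention_kwargs` is forwarded to the transformer's attention processors;
# per the diff, its optional "scale" entry is also read off as the `lora_scale`
# passed to prompt encoding.
image = pipe(
    prompt="corgi",
    image=init_image,
    num_inference_steps=28,
    guidance_scale=5.0,
    joint_attention_kwargs={"scale": 0.8},
).images[0]
image.save("corgi.png")
```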