@@ -13,7 +13,7 @@
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union

import PIL.Image
import torch
@@ -25,7 +25,7 @@
)

from ...image_processor import PipelineImageInput, VaeImageProcessor
from ...loaders import SD3LoraLoaderMixin
from ...loaders import FromSingleFileMixin, SD3LoraLoaderMixin
from ...models.autoencoders import AutoencoderKL
from ...models.transformers import SD3Transformer2DModel
from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -149,7 +149,7 @@ def retrieve_timesteps(
return timesteps, num_inference_steps


class StableDiffusion3Img2ImgPipeline(DiffusionPipeline):
class StableDiffusion3Img2ImgPipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingleFileMixin):
r"""
Args:
transformer ([`SD3Transformer2DModel`]):
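The class now inherits `SD3LoraLoaderMixin` and `FromSingleFileMixin`, which is what makes `load_lora_weights` and `from_single_file` available on the img2img pipeline. A minimal sketch of what that enables; the checkpoint path and LoRA repo id below are illustrative placeholders, not taken from this PR:

```python
import torch

from diffusers import StableDiffusion3Img2ImgPipeline

# FromSingleFileMixin: build the pipeline from a single .safetensors
# checkpoint (placeholder path; a real SD3 checkpoint is required).
pipe = StableDiffusion3Img2ImgPipeline.from_single_file(
    "path/to/sd3_medium_incl_clips.safetensors", torch_dtype=torch.float16
)

# SD3LoraLoaderMixin: attach LoRA weights to the transformer and text
# encoders (placeholder repo id).
pipe.load_lora_weights("some-user/some-sd3-lora")
```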
@@ -680,6 +680,10 @@ def prepare_latents(self, image, timestep, batch_size, num_images_per_prompt, dt
def guidance_scale(self):
return self._guidance_scale

@property
def joint_attention_kwargs(self):
return self._joint_attention_kwargs

@property
def clip_skip(self):
return self._clip_skip
@@ -723,6 +727,7 @@ def __call__(
negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
joint_attention_kwargs: Optional[Dict[str, Any]] = None,
clip_skip: Optional[int] = None,
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
@@ -797,6 +802,10 @@ def __call__(
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead
of a plain tuple.
joint_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that, if specified, is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
callback_on_step_end (`Callable`, *optional*):
A function that is called at the end of each denoising step during inference. The function is called
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
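In practice the new argument is a plain dict forwarded to the attention processors; with LoRA weights loaded, its `"scale"` key also acts as the LoRA strength (see the `lora_scale` extraction further down). A hedged usage sketch, with `pipe` and `init_image` assumed to be prepared by the caller:

```python
# "scale" < 1.0 dampens the LoRA contribution; omitting the dict entirely
# runs the LoRA at its full trained strength.
image = pipe(
    prompt="corgi",
    image=init_image,
    joint_attention_kwargs={"scale": 0.7},
).images[0]
```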
@@ -835,6 +844,7 @@ def __call__(

self._guidance_scale = guidance_scale
self._clip_skip = clip_skip
self._joint_attention_kwargs = joint_attention_kwargs
self._interrupt = False

# 2. Define call parameters
@@ -847,6 +857,10 @@

device = self._execution_device

lora_scale = (
self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
)

(
prompt_embeds,
negative_prompt_embeds,
@@ -868,6 +882,7 @@
clip_skip=self.clip_skip,
num_images_per_prompt=num_images_per_prompt,
max_sequence_length=max_sequence_length,
lora_scale=lora_scale,
)

if self.do_classifier_free_guidance:
@@ -912,6 +927,7 @@ def __call__(
timestep=timestep,
encoder_hidden_states=prompt_embeds,
pooled_projections=pooled_prompt_embeds,
joint_attention_kwargs=self.joint_attention_kwargs,
return_dict=False,
)[0]

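Taken together, the changes wire a single `"scale"` value through both halves of the pipeline: `encode_prompt` receives it as `lora_scale` for the text encoders, and the transformer receives the full dict on every denoising step. An end-to-end sketch, assuming SD3-compatible LoRA weights are available (the LoRA repo id is a placeholder; the base model and image URL are taken from the test below):

```python
import torch

from diffusers import StableDiffusion3Img2ImgPipeline
from diffusers.utils import load_image

pipe = StableDiffusion3Img2ImgPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sd3-lora")  # placeholder repo id

init_image = load_image(
    "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)

# The same 0.7 scales the LoRA layers in the text encoders (via lora_scale)
# and in the transformer (via joint_attention_kwargs).
image = pipe(
    prompt="corgi",
    image=init_image,
    joint_attention_kwargs={"scale": 0.7},
).images[0]
```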
109 changes: 107 additions & 2 deletions tests/lora/test_lora_layers_sd3.py
@@ -12,13 +12,29 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import gc
import sys
import unittest

import numpy as np
import torch
from transformers import AutoTokenizer, CLIPTextModelWithProjection, CLIPTokenizer, T5EncoderModel

from diffusers import FlowMatchEulerDiscreteScheduler, SD3Transformer2DModel, StableDiffusion3Pipeline
from diffusers.utils.testing_utils import is_peft_available, require_peft_backend, require_torch_gpu, torch_device
from diffusers import (
FlowMatchEulerDiscreteScheduler,
SD3Transformer2DModel,
StableDiffusion3Img2ImgPipeline,
StableDiffusion3Pipeline,
)
from diffusers.utils import load_image
from diffusers.utils.import_utils import is_accelerate_available
from diffusers.utils.testing_utils import (
is_peft_available,
numpy_cosine_similarity_distance,
require_peft_backend,
require_torch_gpu,
torch_device,
)


if is_peft_available():
@@ -29,6 +45,10 @@
from utils import PeftLoraLoaderMixinTests # noqa: E402


if is_accelerate_available():
from accelerate.utils import release_memory


@require_peft_backend
class SD3LoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
pipeline_class = StableDiffusion3Pipeline
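`release_memory` is imported from `accelerate` for the new integration test's cleanup. To the best of my understanding it behaves roughly like the sketch below; this is an approximation, not the library source:

```python
import gc

import torch

def release_memory_sketch(*objects):
    # Rough approximation of accelerate.utils.release_memory: drop the
    # references passed in, run a GC pass, and return cached CUDA blocks
    # to the allocator so the next test starts from a clean slate.
    objects = [None] * len(objects)
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return objects
```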
@@ -108,3 +128,88 @@ def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
@unittest.skip("Not supported in SD3.")
def test_modify_padding_mode(self):
pass


@require_torch_gpu
@require_peft_backend
class LoraSD3IntegrationTests(unittest.TestCase):
pipeline_class = StableDiffusion3Img2ImgPipeline
repo_id = "stabilityai/stable-diffusion-3-medium-diffusers"

def setUp(self):
super().setUp()
gc.collect()
torch.cuda.empty_cache()

def tearDown(self):
super().tearDown()
gc.collect()
torch.cuda.empty_cache()

def get_inputs(self, device, seed=0):
init_image = load_image(
"https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd_controlnet/bird_canny.png"
)
if str(device).startswith("mps"):
generator = torch.manual_seed(seed)
else:
generator = torch.Generator(device="cpu").manual_seed(seed)

return {
"prompt": "corgi",
"num_inference_steps": 2,
"guidance_scale": 5.0,
"output_type": "np",
"generator": generator,
"image": init_image,
}

def test_sd3_img2img_lora(self):
pipe = self.pipeline_class.from_pretrained(self.repo_id, torch_dtype=torch.float16)
pipe.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors")
Member:

@SahilCarterr why are we using an SDXL LoRA to test SD3?

Contributor Author:

I think it can be replaced by an SD3 LoRA such as hf-internal-testing/tiny-sd3-loras; I will open a PR for the fix.

Member:

I don't think so, because those LoRAs were meant for much smaller models, so the configs won't match.

This should ideally be tested with "zwloong/sd3-lora-training-rank16-v2" as this is an integration test.
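If that follow-up lands, the offending line would presumably become an SD3-native LoRA load, along the lines of this hypothetical one-liner (not part of this PR):

```python
# Reviewer-suggested replacement for the SDXL LoRA above:
pipe.load_lora_weights("zwloong/sd3-lora-training-rank16-v2")
```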

pipe.enable_sequential_cpu_offload()

inputs = self.get_inputs(torch_device)

image = pipe(**inputs).images[0]
image_slice = image[0, :10, :10]
expected_slice = np.array(
[
0.47827148,
0.5,
0.71972656,
0.3955078,
0.4194336,
0.69628906,
0.37036133,
0.40820312,
0.6923828,
0.36450195,
0.40429688,
0.6904297,
0.35595703,
0.39257812,
0.68652344,
0.35498047,
0.3984375,
0.68310547,
0.34716797,
0.3996582,
0.6855469,
0.3388672,
0.3959961,
0.6816406,
0.34033203,
0.40429688,
0.6845703,
0.34228516,
0.4086914,
0.6870117,
]
)

max_diff = numpy_cosine_similarity_distance(expected_slice.flatten(), image_slice.flatten())

assert max_diff < 1e-4, f"Outputs are not close enough, got {max_diff}"
pipe.unload_lora_weights()
release_memory(pipe)
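For context, `numpy_cosine_similarity_distance` compares the direction of the two flattened slices rather than their element-wise values. A minimal sketch of that metric, as an approximation of the `diffusers.utils.testing_utils` helper rather than its verbatim source:

```python
import numpy as np

def cosine_similarity_distance(a: np.ndarray, b: np.ndarray) -> float:
    # 1 - cosine similarity: close to 0.0 when the vectors point the same
    # way, up to 2.0 when they point in opposite directions.
    similarity = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))
    return float(1.0 - similarity)
```

Under this reading, the `1e-4` threshold tolerates small absolute shifts (e.g. from fp16 nondeterminism) as long as the overall pattern of the slice is preserved.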