Skip to content

Commit f87b7d2

Browse files
authored
Merge branch 'main' into RMS
2 parents 1f7ee3f + b0c8973 commit f87b7d2

File tree

3 files changed

+33
-4
lines changed

3 files changed

+33
-4
lines changed

examples/dreambooth/train_dreambooth_lora_sana.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,9 @@ def log_validation(
158158
f"Running validation... \n Generating {args.num_validation_images} images with prompt:"
159159
f" {args.validation_prompt}."
160160
)
161+
if args.enable_vae_tiling:
162+
pipeline.vae.enable_tiling(tile_sample_min_height=1024, tile_sample_stride_width=1024)
163+
161164
pipeline.text_encoder = pipeline.text_encoder.to(torch.bfloat16)
162165
pipeline = pipeline.to(accelerator.device)
163166
pipeline.set_progress_bar_config(disable=True)
@@ -597,6 +600,7 @@ def parse_args(input_args=None):
597600
help="Whether to offload the VAE and the text encoder to CPU when they are not used.",
598601
)
599602
parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
603+
parser.add_argument("--enable_vae_tiling", action="store_true", help="Enable VAE tiling in log validation")
600604

601605
if input_args is not None:
602606
args = parser.parse_args(input_args)

src/diffusers/pipelines/mochi/pipeline_mochi.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121

2222
from ...callbacks import MultiPipelineCallbacks, PipelineCallback
2323
from ...loaders import Mochi1LoraLoaderMixin
24-
from ...models.autoencoders import AutoencoderKL
24+
from ...models.autoencoders import AutoencoderKLMochi
2525
from ...models.transformers import MochiTransformer3DModel
2626
from ...schedulers import FlowMatchEulerDiscreteScheduler
2727
from ...utils import (
@@ -151,8 +151,8 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
151151
Conditional Transformer architecture to denoise the encoded video latents.
152152
scheduler ([`FlowMatchEulerDiscreteScheduler`]):
153153
A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
154-
vae ([`AutoencoderKL`]):
155-
Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
154+
vae ([`AutoencoderKLMochi`]):
155+
Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
156156
text_encoder ([`T5EncoderModel`]):
157157
[T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
158158
the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
@@ -171,7 +171,7 @@ class MochiPipeline(DiffusionPipeline, Mochi1LoraLoaderMixin):
171171
def __init__(
172172
self,
173173
scheduler: FlowMatchEulerDiscreteScheduler,
174-
vae: AutoencoderKL,
174+
vae: AutoencoderKLMochi,
175175
text_encoder: T5EncoderModel,
176176
tokenizer: T5TokenizerFast,
177177
transformer: MochiTransformer3DModel,

tests/quantization/bnb/test_mixed_int8.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import numpy as np
2020
import pytest
21+
from huggingface_hub import hf_hub_download
2122

2223
from diffusers import BitsAndBytesConfig, DiffusionPipeline, FluxTransformer2DModel, SD3Transformer2DModel, logging
2324
from diffusers.utils import is_accelerate_version
@@ -30,6 +31,7 @@
3031
numpy_cosine_similarity_distance,
3132
require_accelerate,
3233
require_bitsandbytes_version_greater,
34+
require_peft_version_greater,
3335
require_torch,
3436
require_torch_gpu,
3537
require_transformers_version_greater,
@@ -509,6 +511,29 @@ def test_quality(self):
509511
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
510512
self.assertTrue(max_diff < 1e-3)
511513

514+
@require_peft_version_greater("0.14.0")
515+
def test_lora_loading(self):
516+
self.pipeline_8bit.load_lora_weights(
517+
hf_hub_download("ByteDance/Hyper-SD", "Hyper-FLUX.1-dev-8steps-lora.safetensors"), adapter_name="hyper-sd"
518+
)
519+
self.pipeline_8bit.set_adapters("hyper-sd", adapter_weights=0.125)
520+
521+
output = self.pipeline_8bit(
522+
prompt=self.prompt,
523+
height=256,
524+
width=256,
525+
max_sequence_length=64,
526+
output_type="np",
527+
num_inference_steps=8,
528+
generator=torch.manual_seed(42),
529+
).images
530+
out_slice = output[0, -3:, -3:, -1].flatten()
531+
532+
expected_slice = np.array([0.3916, 0.3916, 0.3887, 0.4243, 0.4155, 0.4233, 0.4570, 0.4531, 0.4248])
533+
534+
max_diff = numpy_cosine_similarity_distance(expected_slice, out_slice)
535+
self.assertTrue(max_diff < 1e-3)
536+
512537

513538
@slow
514539
class BaseBnb8bitSerializationTests(Base8bitTests):

0 commit comments

Comments
 (0)