Skip to content

Commit 0632db6

Browse files
committed
add lora tests
1 parent f25934a commit 0632db6

File tree

1 file changed

+134
-0
lines changed

1 file changed

+134
-0
lines changed
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# coding=utf-8
2+
# Copyright 2025 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
import sys
16+
import unittest
17+
18+
import torch
19+
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2Tokenizer
20+
21+
from diffusers import (
22+
AutoencoderKLQwenImage,
23+
FlowMatchEulerDiscreteScheduler,
24+
QwenImagePipeline,
25+
QwenImageTransformer2DModel,
26+
)
27+
from diffusers.utils.testing_utils import floats_tensor, is_peft_available, require_peft_backend
28+
29+
30+
if is_peft_available():
31+
pass
32+
33+
sys.path.append(".")
34+
35+
from utils import PeftLoraLoaderMixinTests # noqa: E402
36+
37+
38+
@require_peft_backend
39+
class QwenImageLoRATests(unittest.TestCase, PeftLoraLoaderMixinTests):
40+
pipeline_class = QwenImagePipeline
41+
scheduler_cls = FlowMatchEulerDiscreteScheduler
42+
scheduler_classes = [FlowMatchEulerDiscreteScheduler]
43+
scheduler_kwargs = {}
44+
45+
transformer_kwargs = {
46+
"patch_size": 2,
47+
"in_channels": 16,
48+
"out_channels": 4,
49+
"num_layers": 2,
50+
"attention_head_dim": 16,
51+
"num_attention_heads": 3,
52+
"joint_attention_dim": 16,
53+
"guidance_embeds": False,
54+
"axes_dims_rope": (8, 4, 4),
55+
}
56+
transformer_cls = QwenImageTransformer2DModel
57+
z_dim = 4
58+
vae_kwargs = {
59+
"base_dim": z_dim * 6,
60+
"z_dim": z_dim,
61+
"dim_mult": [1, 2, 4],
62+
"num_res_blocks": 1,
63+
"temperal_downsample": [False, True],
64+
# fmt: off
65+
"latents_mean": [0.0] * 4,
66+
"latents_std": [1.0] * 4,
67+
# fmt: on
68+
}
69+
vae_cls = AutoencoderKLQwenImage
70+
tokenizer_cls, tokenizer_id = Qwen2Tokenizer, "hf-internal-testing/tiny-random-Qwen25VLForCondGen"
71+
text_encoder_cls, text_encoder_id = (
72+
Qwen2_5_VLForConditionalGeneration,
73+
"hf-internal-testing/tiny-random-Qwen25VLForCondGen",
74+
)
75+
denoiser_target_modules = ["to_q", "to_k", "to_v", "to_out.0"]
76+
77+
@property
78+
def output_shape(self):
79+
return (1, 8, 8, 3)
80+
81+
def get_dummy_inputs(self, with_generator=True):
82+
batch_size = 1
83+
sequence_length = 10
84+
num_channels = 4
85+
sizes = (32, 32)
86+
87+
generator = torch.manual_seed(0)
88+
noise = floats_tensor((batch_size, num_channels) + sizes)
89+
input_ids = torch.randint(1, sequence_length, size=(batch_size, sequence_length), generator=generator)
90+
91+
pipeline_inputs = {
92+
"prompt": "A painting of a squirrel eating a burger",
93+
"num_inference_steps": 4,
94+
"guidance_scale": 0.0,
95+
"height": 8,
96+
"width": 8,
97+
"output_type": "np",
98+
}
99+
if with_generator:
100+
pipeline_inputs.update({"generator": generator})
101+
102+
return noise, input_ids, pipeline_inputs
103+
104+
@unittest.skip("Not supported in Qwen Image.")
105+
def test_simple_inference_with_text_denoiser_block_scale(self):
106+
pass
107+
108+
@unittest.skip("Not supported in Qwen Image.")
109+
def test_simple_inference_with_text_denoiser_block_scale_for_all_dict_options(self):
110+
pass
111+
112+
@unittest.skip("Not supported in Qwen Image.")
113+
def test_modify_padding_mode(self):
114+
pass
115+
116+
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
117+
def test_simple_inference_with_partial_text_lora(self):
118+
pass
119+
120+
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
121+
def test_simple_inference_with_text_lora(self):
122+
pass
123+
124+
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
125+
def test_simple_inference_with_text_lora_and_scale(self):
126+
pass
127+
128+
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
129+
def test_simple_inference_with_text_lora_fused(self):
130+
pass
131+
132+
@unittest.skip("Text encoder LoRA is not supported in Qwen Image.")
133+
def test_simple_inference_with_text_lora_save_load(self):
134+
pass

0 commit comments

Comments
 (0)