Commit ef83f04 ("add tests"), 1 parent: 69ae746

1 file changed: 248 additions, 0 deletions
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import json
import logging
import os
import sys
import tempfile

import safetensors

from diffusers.loaders.lora_base import LORA_ADAPTER_METADATA_KEY


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAQwenImage(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-qwenimage-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_qwen_image.py"
    transformer_layer_type = "single_transformer_blocks.0.attn.to_k"

    def test_dreambooth_lora_qwen(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_layers(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lora_layers {self.transformer_layer_type}
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names. In this test, only params of
            # `transformer.single_transformer_blocks.0.attn.to_k` should be in the state dict.
            starts_with_transformer = all(
                key.startswith("transformer.single_transformer_blocks.0.attn.to_k") for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_qwen_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

    def test_dreambooth_lora_with_metadata(self):
        # Use a `lora_alpha` that is different from `rank`.
        lora_alpha = 8
        rank = 4
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --lora_alpha={lora_alpha}
                --rank={rank}
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            state_dict_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(state_dict_file))

            # Check if the metadata was properly serialized.
            with safetensors.torch.safe_open(state_dict_file, framework="pt", device="cpu") as f:
                metadata = f.metadata() or {}

            metadata.pop("format", None)
            raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
            if raw:
                raw = json.loads(raw)

            loaded_lora_alpha = raw["transformer.lora_alpha"]
            self.assertTrue(loaded_lora_alpha == lora_alpha)
            loaded_lora_rank = raw["transformer.r"]
            self.assertTrue(loaded_lora_rank == rank)
