Skip to content

Commit 22e9ae8

Browse files
committed
add tests
1 parent c439c89 commit 22e9ae8

File tree

1 file changed

+206
-0
lines changed

1 file changed

+206
-0
lines changed
Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
import safetensors
22+
23+
24+
sys.path.append("..")
25+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
26+
27+
28+
logging.basicConfig(level=logging.DEBUG)
29+
30+
logger = logging.getLogger()
31+
stream_handler = logging.StreamHandler(sys.stdout)
32+
logger.addHandler(stream_handler)
33+
34+
35+
class DreamBoothLoRAHiDreamImage(ExamplesTestsAccelerate):
36+
instance_data_dir = "docs/source/en/imgs"
37+
pretrained_model_name_or_path = "hf-internal-testing/tiny-hidream-i1-pipe"
38+
script_path = "examples/dreambooth/train_dreambooth_lora_hidream.py"
39+
transformer_layer_type = "double_stream_blocks.0.block.attn1.to_k"
40+
41+
def test_dreambooth_lora_hidream(self):
42+
with tempfile.TemporaryDirectory() as tmpdir:
43+
test_args = f"""
44+
{self.script_path}
45+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
46+
--instance_data_dir {self.instance_data_dir}
47+
--resolution 32
48+
--train_batch_size 1
49+
--gradient_accumulation_steps 1
50+
--max_train_steps 2
51+
--learning_rate 5.0e-04
52+
--scale_lr
53+
--lr_scheduler constant
54+
--lr_warmup_steps 0
55+
--output_dir {tmpdir}
56+
--max_sequence_length 16
57+
""".split()
58+
59+
test_args.extend(["--instance_prompt", ""])
60+
run_command(self._launch_args + test_args)
61+
# save_pretrained smoke test
62+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
63+
64+
# make sure the state_dict has the correct naming in the parameters.
65+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
66+
is_lora = all("lora" in k for k in lora_state_dict.keys())
67+
self.assertTrue(is_lora)
68+
69+
# when not training the text encoder, all the parameters in the state dict should start
70+
# with `"transformer"` in their names.
71+
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
72+
self.assertTrue(starts_with_transformer)
73+
74+
def test_dreambooth_lora_latent_caching(self):
75+
with tempfile.TemporaryDirectory() as tmpdir:
76+
test_args = f"""
77+
{self.script_path}
78+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
79+
--instance_data_dir {self.instance_data_dir}
80+
--resolution 32
81+
--train_batch_size 1
82+
--gradient_accumulation_steps 1
83+
--max_train_steps 2
84+
--cache_latents
85+
--learning_rate 5.0e-04
86+
--scale_lr
87+
--lr_scheduler constant
88+
--lr_warmup_steps 0
89+
--output_dir {tmpdir}
90+
--max_sequence_length 16
91+
""".split()
92+
93+
test_args.extend(["--instance_prompt", ""])
94+
run_command(self._launch_args + test_args)
95+
# save_pretrained smoke test
96+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
97+
98+
# make sure the state_dict has the correct naming in the parameters.
99+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
100+
is_lora = all("lora" in k for k in lora_state_dict.keys())
101+
self.assertTrue(is_lora)
102+
103+
# when not training the text encoder, all the parameters in the state dict should start
104+
# with `"transformer"` in their names.
105+
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
106+
self.assertTrue(starts_with_transformer)
107+
108+
def test_dreambooth_lora_layers(self):
109+
with tempfile.TemporaryDirectory() as tmpdir:
110+
test_args = f"""
111+
{self.script_path}
112+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
113+
--instance_data_dir {self.instance_data_dir}
114+
--resolution 32
115+
--train_batch_size 1
116+
--gradient_accumulation_steps 1
117+
--max_train_steps 2
118+
--cache_latents
119+
--learning_rate 5.0e-04
120+
--scale_lr
121+
--lora_layers {self.transformer_layer_type}
122+
--lr_scheduler constant
123+
--lr_warmup_steps 0
124+
--output_dir {tmpdir}
125+
--max_sequence_length 16
126+
""".split()
127+
128+
test_args.extend(["--instance_prompt", ""])
129+
run_command(self._launch_args + test_args)
130+
# save_pretrained smoke test
131+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
132+
133+
# make sure the state_dict has the correct naming in the parameters.
134+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
135+
is_lora = all("lora" in k for k in lora_state_dict.keys())
136+
self.assertTrue(is_lora)
137+
138+
# when not training the text encoder, all the parameters in the state dict should start
139+
# with `"transformer"` in their names. In this test, we only params of
140+
# `self.transformer_layer_type` should be in the state dict.
141+
starts_with_transformer = all(self.transformer_layer_type in key for key in lora_state_dict)
142+
self.assertTrue(starts_with_transformer)
143+
144+
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit(self):
145+
with tempfile.TemporaryDirectory() as tmpdir:
146+
test_args = f"""
147+
{self.script_path}
148+
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
149+
--instance_data_dir={self.instance_data_dir}
150+
--output_dir={tmpdir}
151+
--resolution=32
152+
--train_batch_size=1
153+
--gradient_accumulation_steps=1
154+
--max_train_steps=6
155+
--checkpoints_total_limit=2
156+
--checkpointing_steps=2
157+
--max_sequence_length 16
158+
""".split()
159+
160+
test_args.extend(["--instance_prompt", ""])
161+
run_command(self._launch_args + test_args)
162+
163+
self.assertEqual(
164+
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
165+
{"checkpoint-4", "checkpoint-6"},
166+
)
167+
168+
def test_dreambooth_lora_hidream_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
169+
with tempfile.TemporaryDirectory() as tmpdir:
170+
test_args = f"""
171+
{self.script_path}
172+
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
173+
--instance_data_dir={self.instance_data_dir}
174+
--output_dir={tmpdir}
175+
--resolution=32
176+
--train_batch_size=1
177+
--gradient_accumulation_steps=1
178+
--max_train_steps=4
179+
--checkpointing_steps=2
180+
--max_sequence_length 166
181+
""".split()
182+
183+
test_args.extend(["--instance_prompt", ""])
184+
run_command(self._launch_args + test_args)
185+
186+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
187+
188+
resume_run_args = f"""
189+
{self.script_path}
190+
--pretrained_model_name_or_path={self.pretrained_model_name_or_path}
191+
--instance_data_dir={self.instance_data_dir}
192+
--output_dir={tmpdir}
193+
--resolution=32
194+
--train_batch_size=1
195+
--gradient_accumulation_steps=1
196+
--max_train_steps=8
197+
--checkpointing_steps=2
198+
--resume_from_checkpoint=checkpoint-4
199+
--checkpoints_total_limit=2
200+
--max_sequence_length 16
201+
""".split()
202+
203+
resume_run_args.extend(["--instance_prompt", ""])
204+
run_command(self._launch_args + resume_run_args)
205+
206+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

0 commit comments

Comments
 (0)