
Commit cb265ad

add tests
1 parent c41dfff commit cb265ad

1 file changed: +279, -0 lines

@@ -0,0 +1,279 @@
# coding=utf-8
# Copyright 2024 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile

import safetensors


sys.path.append("..")
from test_examples_utils import ExamplesTestsAccelerate, run_command  # noqa: E402


logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)


class DreamBoothLoRAFluxAdvanced(ExamplesTestsAccelerate):
    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-flux-pipe"
    script_path = "examples/advanced_diffusion_training/train_dreambooth_lora_flux_advanced.py"

    def test_dreambooth_lora_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_text_encoder_flux(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            starts_with_expected_prefix = all(
                (key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
            )
            self.assertTrue(starts_with_expected_prefix)

    def test_dreambooth_lora_pivotal_tuning_flux_clip(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder_ti
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
            # make sure embeddings were also saved
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # make sure the state_dict has the correct naming in the parameters.
            textual_inversion_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors"))
            is_clip = all("clip_l" in k for k in textual_inversion_state_dict.keys())
            self.assertTrue(is_clip)

            # when performing pivotal tuning, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_pivotal_tuning_flux_clip_t5(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --train_text_encoder_ti
                --enable_t5_ti
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
            # make sure embeddings were also saved
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # make sure the state_dict has the correct naming in the parameters.
            textual_inversion_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, f"{tmpdir}_emb.safetensors"))
            is_te = all(("clip_l" in k or "t5" in k) for k in textual_inversion_state_dict.keys())
            self.assertTrue(is_te)

            # when performing pivotal tuning, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_latent_caching(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path {self.pretrained_model_name_or_path}
                --instance_data_dir {self.instance_data_dir}
                --instance_prompt {self.instance_prompt}
                --resolution 64
                --train_batch_size 1
                --gradient_accumulation_steps 1
                --max_train_steps 2
                --cache_latents
                --learning_rate 5.0e-04
                --scale_lr
                --lr_scheduler constant
                --lr_warmup_steps 0
                --output_dir {tmpdir}
                """.split()

            run_command(self._launch_args + test_args)
            # save_pretrained smoke test
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
            is_lora = all("lora" in k for k in lora_state_dict.keys())
            self.assertTrue(is_lora)

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
            self.assertTrue(starts_with_transformer)

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=6
                --checkpoints_total_limit=2
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual(
                {x for x in os.listdir(tmpdir) if "checkpoint" in x},
                {"checkpoint-4", "checkpoint-6"},
            )

    def test_dreambooth_lora_flux_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=4
                --checkpointing_steps=2
                """.split()

            run_command(self._launch_args + test_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = f"""
                {self.script_path}
                --pretrained_model_name_or_path={self.pretrained_model_name_or_path}
                --instance_data_dir={self.instance_data_dir}
                --output_dir={tmpdir}
                --instance_prompt={self.instance_prompt}
                --resolution=64
                --train_batch_size=1
                --gradient_accumulation_steps=1
                --max_train_steps=8
                --checkpointing_steps=2
                --resume_from_checkpoint=checkpoint-4
                --checkpoints_total_limit=2
                """.split()

            run_command(self._launch_args + resume_run_args)

            self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
