Skip to content

Commit 4e57aef

Browse files
authored
[Tests] add test suite for SD3 DreamBooth (#8650)
Adds a test suite for SD3 DreamBooth; adds a LoRA suite; applies style fixes; adds checkpointing tests for LoRA; adds a test to cover train_text_encoder.
1 parent af92869 commit 4e57aef

File tree

2 files changed

+368
-0
lines changed

2 files changed

+368
-0
lines changed
Lines changed: 165 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,165 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import sys
19+
import tempfile
20+
21+
import safetensors
22+
23+
24+
sys.path.append("..")
25+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
26+
27+
28+
# Emit DEBUG-level logs and mirror them to stdout so the output of the
# launched example scripts is visible in the test runner's captured logs.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
33+
34+
35+
class DreamBoothLoRASD3(ExamplesTestsAccelerate):
    """Smoke tests for ``examples/dreambooth/train_dreambooth_lora_sd3.py``.

    Each test launches the training script via ``accelerate`` on a tiny test
    checkpoint and asserts on the artifacts written to a temporary directory.
    """

    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
    script_path = "examples/dreambooth/train_dreambooth_lora_sd3.py"

    def test_dreambooth_lora_sd3(self):
        # Train for two steps without text-encoder training, then inspect
        # the serialized LoRA state dict.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = [
                self.script_path,
                "--pretrained_model_name_or_path", self.pretrained_model_name_or_path,
                "--instance_data_dir", self.instance_data_dir,
                "--instance_prompt", self.instance_prompt,
                "--resolution", "64",
                "--train_batch_size", "1",
                "--gradient_accumulation_steps", "1",
                "--max_train_steps", "2",
                "--learning_rate", "5.0e-04",
                "--scale_lr",
                "--lr_scheduler", "constant",
                "--lr_warmup_steps", "0",
                "--output_dir", tmpdir,
            ]

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            weights_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(weights_file))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(weights_file)
            self.assertTrue(all("lora" in key for key in lora_state_dict))

            # when not training the text encoder, all the parameters in the state dict should start
            # with `"transformer"` in their names.
            self.assertTrue(all(key.startswith("transformer") for key in lora_state_dict))

    def test_dreambooth_lora_text_encoder_sd3(self):
        # Same as above but with --train_text_encoder: text-encoder LoRA
        # parameters must also appear in the saved state dict.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = [
                self.script_path,
                "--pretrained_model_name_or_path", self.pretrained_model_name_or_path,
                "--instance_data_dir", self.instance_data_dir,
                "--instance_prompt", self.instance_prompt,
                "--resolution", "64",
                "--train_batch_size", "1",
                "--train_text_encoder",
                "--gradient_accumulation_steps", "1",
                "--max_train_steps", "2",
                "--learning_rate", "5.0e-04",
                "--scale_lr",
                "--lr_scheduler", "constant",
                "--lr_warmup_steps", "0",
                "--output_dir", tmpdir,
            ]

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            weights_file = os.path.join(tmpdir, "pytorch_lora_weights.safetensors")
            self.assertTrue(os.path.isfile(weights_file))

            # make sure the state_dict has the correct naming in the parameters.
            lora_state_dict = safetensors.torch.load_file(weights_file)
            self.assertTrue(all("lora" in key for key in lora_state_dict))

            # Parameters may belong to either the transformer or the text encoder.
            prefixes_ok = all(
                key.startswith("transformer") or key.startswith("text_encoder") for key in lora_state_dict
            )
            self.assertTrue(prefixes_ok)

    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
        # With 6 steps, checkpoints every 2 steps, and a limit of 2, only the
        # two most recent checkpoints may survive.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = [
                self.script_path,
                f"--pretrained_model_name_or_path={self.pretrained_model_name_or_path}",
                f"--instance_data_dir={self.instance_data_dir}",
                f"--output_dir={tmpdir}",
                f"--instance_prompt={self.instance_prompt}",
                "--resolution=64",
                "--train_batch_size=1",
                "--gradient_accumulation_steps=1",
                "--max_train_steps=6",
                "--checkpoints_total_limit=2",
                "--checkpointing_steps=2",
            ]

            run_command(self._launch_args + test_args)

            checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
            self.assertEqual(checkpoints, {"checkpoint-4", "checkpoint-6"})

    def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        # First run without a limit, then resume with --checkpoints_total_limit=2:
        # resuming must prune every checkpoint older than the two newest.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = [
                self.script_path,
                f"--pretrained_model_name_or_path={self.pretrained_model_name_or_path}",
                f"--instance_data_dir={self.instance_data_dir}",
                f"--output_dir={tmpdir}",
                f"--instance_prompt={self.instance_prompt}",
                "--resolution=64",
                "--train_batch_size=1",
                "--gradient_accumulation_steps=1",
                "--max_train_steps=4",
                "--checkpointing_steps=2",
            ]

            run_command(self._launch_args + test_args)

            checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
            self.assertEqual(checkpoints, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = [
                self.script_path,
                f"--pretrained_model_name_or_path={self.pretrained_model_name_or_path}",
                f"--instance_data_dir={self.instance_data_dir}",
                f"--output_dir={tmpdir}",
                f"--instance_prompt={self.instance_prompt}",
                "--resolution=64",
                "--train_batch_size=1",
                "--gradient_accumulation_steps=1",
                "--max_train_steps=8",
                "--checkpointing_steps=2",
                "--resume_from_checkpoint=checkpoint-4",
                "--checkpoints_total_limit=2",
            ]

            run_command(self._launch_args + resume_run_args)

            checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
            self.assertEqual(checkpoints, {"checkpoint-6", "checkpoint-8"})
Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
# coding=utf-8
2+
# Copyright 2024 HuggingFace Inc.
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
import logging
17+
import os
18+
import shutil
19+
import sys
20+
import tempfile
21+
22+
from diffusers import DiffusionPipeline, SD3Transformer2DModel
23+
24+
25+
sys.path.append("..")
26+
from test_examples_utils import ExamplesTestsAccelerate, run_command # noqa: E402
27+
28+
29+
# Emit DEBUG-level logs and mirror them to stdout so the output of the
# launched example scripts is visible in the test runner's captured logs.
logging.basicConfig(level=logging.DEBUG)

logger = logging.getLogger()
stream_handler = logging.StreamHandler(sys.stdout)
logger.addHandler(stream_handler)
34+
35+
36+
class DreamBoothSD3(ExamplesTestsAccelerate):
    """Smoke tests for ``examples/dreambooth/train_dreambooth_sd3.py``.

    Each test launches the full-finetuning training script via ``accelerate``
    on a tiny test checkpoint and asserts on the artifacts it writes.
    """

    instance_data_dir = "docs/source/en/imgs"
    instance_prompt = "photo"
    pretrained_model_name_or_path = "hf-internal-testing/tiny-sd3-pipe"
    script_path = "examples/dreambooth/train_dreambooth_sd3.py"

    def test_dreambooth(self):
        # Minimal two-step run; verify save_pretrained wrote the expected layout.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = [
                self.script_path,
                "--pretrained_model_name_or_path", self.pretrained_model_name_or_path,
                "--instance_data_dir", self.instance_data_dir,
                "--instance_prompt", self.instance_prompt,
                "--resolution", "64",
                "--train_batch_size", "1",
                "--gradient_accumulation_steps", "1",
                "--max_train_steps", "2",
                "--learning_rate", "5.0e-04",
                "--scale_lr",
                "--lr_scheduler", "constant",
                "--lr_warmup_steps", "0",
                "--output_dir", tmpdir,
            ]

            run_command(self._launch_args + test_args)

            # save_pretrained smoke test
            self.assertTrue(
                os.path.isfile(os.path.join(tmpdir, "transformer", "diffusion_pytorch_model.safetensors"))
            )
            self.assertTrue(os.path.isfile(os.path.join(tmpdir, "scheduler", "scheduler_config.json")))

    def test_dreambooth_checkpointing(self):
        with tempfile.TemporaryDirectory() as tmpdir:
            # First run: max_train_steps == 4 with checkpointing_steps == 2
            # should leave checkpoints at steps 2 and 4.
            initial_run_args = [
                self.script_path,
                "--pretrained_model_name_or_path", self.pretrained_model_name_or_path,
                "--instance_data_dir", self.instance_data_dir,
                "--instance_prompt", self.instance_prompt,
                "--resolution", "64",
                "--train_batch_size", "1",
                "--gradient_accumulation_steps", "1",
                "--max_train_steps", "4",
                "--learning_rate", "5.0e-04",
                "--scale_lr",
                "--lr_scheduler", "constant",
                "--lr_warmup_steps", "0",
                "--output_dir", tmpdir,
                "--checkpointing_steps=2",
                "--seed=0",
            ]

            run_command(self._launch_args + initial_run_args)

            # The fully trained output pipeline must load and run.
            pipe = DiffusionPipeline.from_pretrained(tmpdir)
            pipe(self.instance_prompt, num_inference_steps=1)

            # Both checkpoint directories must exist.
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))

            # An intermediate checkpoint must be loadable and runnable too.
            transformer = SD3Transformer2DModel.from_pretrained(tmpdir, subfolder="checkpoint-2/transformer")
            pipe = DiffusionPipeline.from_pretrained(self.pretrained_model_name_or_path, transformer=transformer)
            pipe(self.instance_prompt, num_inference_steps=1)

            # Remove checkpoint 2 so we can check that resuming does not recreate it.
            shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))

            # Second run: resume from checkpoint 4 and train up to 6 total steps.
            resume_run_args = [
                self.script_path,
                "--pretrained_model_name_or_path", self.pretrained_model_name_or_path,
                "--instance_data_dir", self.instance_data_dir,
                "--instance_prompt", self.instance_prompt,
                "--resolution", "64",
                "--train_batch_size", "1",
                "--gradient_accumulation_steps", "1",
                "--max_train_steps", "6",
                "--learning_rate", "5.0e-04",
                "--scale_lr",
                "--lr_scheduler", "constant",
                "--lr_warmup_steps", "0",
                "--output_dir", tmpdir,
                "--checkpointing_steps=2",
                "--resume_from_checkpoint=checkpoint-4",
                "--seed=0",
            ]

            run_command(self._launch_args + resume_run_args)

            # The resumed, fully trained pipeline must load and run.
            pipe = DiffusionPipeline.from_pretrained(tmpdir)
            pipe(self.instance_prompt, num_inference_steps=1)

            # The deleted checkpoint must not have been recreated ...
            self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))

            # ... while the resumed-from and the new checkpoints exist.
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-4")))
            self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-6")))

    def test_dreambooth_checkpointing_checkpoints_total_limit(self):
        # With 6 steps, checkpoints every 2 steps, and a limit of 2, only the
        # two most recent checkpoints may survive.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = [
                self.script_path,
                f"--pretrained_model_name_or_path={self.pretrained_model_name_or_path}",
                f"--instance_data_dir={self.instance_data_dir}",
                f"--output_dir={tmpdir}",
                f"--instance_prompt={self.instance_prompt}",
                "--resolution=64",
                "--train_batch_size=1",
                "--gradient_accumulation_steps=1",
                "--max_train_steps=6",
                "--checkpoints_total_limit=2",
                "--checkpointing_steps=2",
            ]

            run_command(self._launch_args + test_args)

            checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
            self.assertEqual(checkpoints, {"checkpoint-4", "checkpoint-6"})

    def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
        # First run without a limit, then resume with --checkpoints_total_limit=2:
        # resuming must prune every checkpoint older than the two newest.
        with tempfile.TemporaryDirectory() as tmpdir:
            test_args = [
                self.script_path,
                f"--pretrained_model_name_or_path={self.pretrained_model_name_or_path}",
                f"--instance_data_dir={self.instance_data_dir}",
                f"--output_dir={tmpdir}",
                f"--instance_prompt={self.instance_prompt}",
                "--resolution=64",
                "--train_batch_size=1",
                "--gradient_accumulation_steps=1",
                "--max_train_steps=4",
                "--checkpointing_steps=2",
            ]

            run_command(self._launch_args + test_args)

            checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
            self.assertEqual(checkpoints, {"checkpoint-2", "checkpoint-4"})

            resume_run_args = [
                self.script_path,
                f"--pretrained_model_name_or_path={self.pretrained_model_name_or_path}",
                f"--instance_data_dir={self.instance_data_dir}",
                f"--output_dir={tmpdir}",
                f"--instance_prompt={self.instance_prompt}",
                "--resolution=64",
                "--train_batch_size=1",
                "--gradient_accumulation_steps=1",
                "--max_train_steps=8",
                "--checkpointing_steps=2",
                "--resume_from_checkpoint=checkpoint-4",
                "--checkpoints_total_limit=2",
            ]

            run_command(self._launch_args + resume_run_args)

            checkpoints = {entry for entry in os.listdir(tmpdir) if "checkpoint" in entry}
            self.assertEqual(checkpoints, {"checkpoint-6", "checkpoint-8"})

0 commit comments

Comments
 (0)