Skip to content

Commit f4b0b26

Browse files
authored
[Tests] Speed up example tests (#6319)
* remove validation args from textual onverson tests * reduce number of train steps in textual inversion tests * fix: directories. * debig * fix: directories. * remove validation tests from textual onversion * try reducing the time of test_text_to_image_checkpointing_use_ema * fix: directories * speed up test_text_to_image_checkpointing * speed up test_text_to_image_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * fix * speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * set checkpoints_total_limit to 2. * test_text_to_image_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints speed up * speed up test_unconditional_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * debug * fix: directories. * speed up test_instruct_pix2pix_checkpointing_checkpoints_total_limit * speed up: test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * speed up test_controlnet_sdxl * speed up dreambooth tests * speed up test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * speed up test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints * speed up test_text_to_image_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit * speed up # checkpoint-2 should have been deleted * speed up examples/text_to_image/test_text_to_image.py::TextToImage::test_text_to_image_checkpointing_checkpoints_total_limit * additional speed ups * style
1 parent 89459a5 commit f4b0b26

File tree

9 files changed

+117
-164
lines changed

9 files changed

+117
-164
lines changed

examples/controlnet/test_controlnet.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -65,15 +65,15 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check
6565
--train_batch_size=1
6666
--gradient_accumulation_steps=1
6767
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
68-
--max_train_steps=9
68+
--max_train_steps=6
6969
--checkpointing_steps=2
7070
""".split()
7171

7272
run_command(self._launch_args + test_args)
7373

7474
self.assertEqual(
7575
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
76-
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
76+
{"checkpoint-2", "checkpoint-4", "checkpoint-6"},
7777
)
7878

7979
resume_run_args = f"""
@@ -85,18 +85,15 @@ def test_controlnet_checkpointing_checkpoints_total_limit_removes_multiple_check
8585
--train_batch_size=1
8686
--gradient_accumulation_steps=1
8787
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet
88-
--max_train_steps=11
88+
--max_train_steps=8
8989
--checkpointing_steps=2
90-
--resume_from_checkpoint=checkpoint-8
91-
--checkpoints_total_limit=3
90+
--resume_from_checkpoint=checkpoint-6
91+
--checkpoints_total_limit=2
9292
""".split()
9393

9494
run_command(self._launch_args + resume_run_args)
9595

96-
self.assertEqual(
97-
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
98-
{"checkpoint-8", "checkpoint-10", "checkpoint-12"},
99-
)
96+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
10097

10198

10299
class ControlNetSDXL(ExamplesTestsAccelerate):
@@ -111,7 +108,7 @@ def test_controlnet_sdxl(self):
111108
--train_batch_size=1
112109
--gradient_accumulation_steps=1
113110
--controlnet_model_name_or_path=hf-internal-testing/tiny-controlnet-sdxl
114-
--max_train_steps=9
111+
--max_train_steps=4
115112
--checkpointing_steps=2
116113
""".split()
117114

examples/custom_diffusion/test_custom_diffusion.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -76,10 +76,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit(self):
7676

7777
run_command(self._launch_args + test_args)
7878

79-
self.assertEqual(
80-
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
81-
{"checkpoint-4", "checkpoint-6"},
82-
)
79+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
8380

8481
def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple_checkpoints(self):
8582
with tempfile.TemporaryDirectory() as tmpdir:
@@ -93,7 +90,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
9390
--train_batch_size=1
9491
--modifier_token=<new1>
9592
--dataloader_num_workers=0
96-
--max_train_steps=9
93+
--max_train_steps=4
9794
--checkpointing_steps=2
9895
--no_safe_serialization
9996
""".split()
@@ -102,7 +99,7 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
10299

103100
self.assertEqual(
104101
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
105-
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
102+
{"checkpoint-2", "checkpoint-4"},
106103
)
107104

108105
resume_run_args = f"""
@@ -115,16 +112,13 @@ def test_custom_diffusion_checkpointing_checkpoints_total_limit_removes_multiple
115112
--train_batch_size=1
116113
--modifier_token=<new1>
117114
--dataloader_num_workers=0
118-
--max_train_steps=11
115+
--max_train_steps=8
119116
--checkpointing_steps=2
120-
--resume_from_checkpoint=checkpoint-8
121-
--checkpoints_total_limit=3
117+
--resume_from_checkpoint=checkpoint-4
118+
--checkpoints_total_limit=2
122119
--no_safe_serialization
123120
""".split()
124121

125122
run_command(self._launch_args + resume_run_args)
126123

127-
self.assertEqual(
128-
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
129-
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
130-
)
124+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

examples/dreambooth/test_dreambooth.py

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ def test_dreambooth_checkpointing(self):
8989

9090
with tempfile.TemporaryDirectory() as tmpdir:
9191
# Run training script with checkpointing
92-
# max_train_steps == 5, checkpointing_steps == 2
92+
# max_train_steps == 4, checkpointing_steps == 2
9393
# Should create checkpoints at steps 2, 4
9494

9595
initial_run_args = f"""
@@ -100,7 +100,7 @@ def test_dreambooth_checkpointing(self):
100100
--resolution 64
101101
--train_batch_size 1
102102
--gradient_accumulation_steps 1
103-
--max_train_steps 5
103+
--max_train_steps 4
104104
--learning_rate 5.0e-04
105105
--scale_lr
106106
--lr_scheduler constant
@@ -114,7 +114,7 @@ def test_dreambooth_checkpointing(self):
114114

115115
# check can run the original fully trained output pipeline
116116
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
117-
pipe(instance_prompt, num_inference_steps=2)
117+
pipe(instance_prompt, num_inference_steps=1)
118118

119119
# check checkpoint directories exist
120120
self.assertTrue(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
@@ -123,7 +123,7 @@ def test_dreambooth_checkpointing(self):
123123
# check can run an intermediate checkpoint
124124
unet = UNet2DConditionModel.from_pretrained(tmpdir, subfolder="checkpoint-2/unet")
125125
pipe = DiffusionPipeline.from_pretrained(pretrained_model_name_or_path, unet=unet, safety_checker=None)
126-
pipe(instance_prompt, num_inference_steps=2)
126+
pipe(instance_prompt, num_inference_steps=1)
127127

128128
# Remove checkpoint 2 so that we can check only later checkpoints exist after resuming
129129
shutil.rmtree(os.path.join(tmpdir, "checkpoint-2"))
@@ -138,7 +138,7 @@ def test_dreambooth_checkpointing(self):
138138
--resolution 64
139139
--train_batch_size 1
140140
--gradient_accumulation_steps 1
141-
--max_train_steps 7
141+
--max_train_steps 6
142142
--learning_rate 5.0e-04
143143
--scale_lr
144144
--lr_scheduler constant
@@ -153,7 +153,7 @@ def test_dreambooth_checkpointing(self):
153153

154154
# check can run new fully trained pipeline
155155
pipe = DiffusionPipeline.from_pretrained(tmpdir, safety_checker=None)
156-
pipe(instance_prompt, num_inference_steps=2)
156+
pipe(instance_prompt, num_inference_steps=1)
157157

158158
# check old checkpoints do not exist
159159
self.assertFalse(os.path.isdir(os.path.join(tmpdir, "checkpoint-2")))
@@ -196,15 +196,15 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check
196196
--resolution=64
197197
--train_batch_size=1
198198
--gradient_accumulation_steps=1
199-
--max_train_steps=9
199+
--max_train_steps=4
200200
--checkpointing_steps=2
201201
""".split()
202202

203203
run_command(self._launch_args + test_args)
204204

205205
self.assertEqual(
206206
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
207-
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
207+
{"checkpoint-2", "checkpoint-4"},
208208
)
209209

210210
resume_run_args = f"""
@@ -216,15 +216,12 @@ def test_dreambooth_checkpointing_checkpoints_total_limit_removes_multiple_check
216216
--resolution=64
217217
--train_batch_size=1
218218
--gradient_accumulation_steps=1
219-
--max_train_steps=11
219+
--max_train_steps=8
220220
--checkpointing_steps=2
221-
--resume_from_checkpoint=checkpoint-8
222-
--checkpoints_total_limit=3
221+
--resume_from_checkpoint=checkpoint-4
222+
--checkpoints_total_limit=2
223223
""".split()
224224

225225
run_command(self._launch_args + resume_run_args)
226226

227-
self.assertEqual(
228-
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
229-
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
230-
)
227+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})

examples/dreambooth/test_dreambooth_lora.py

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -135,16 +135,13 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_
135135
--resolution=64
136136
--train_batch_size=1
137137
--gradient_accumulation_steps=1
138-
--max_train_steps=9
138+
--max_train_steps=4
139139
--checkpointing_steps=2
140140
""".split()
141141

142142
run_command(self._launch_args + test_args)
143143

144-
self.assertEqual(
145-
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
146-
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
147-
)
144+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-2", "checkpoint-4"})
148145

149146
resume_run_args = f"""
150147
examples/dreambooth/train_dreambooth_lora.py
@@ -155,18 +152,15 @@ def test_dreambooth_lora_checkpointing_checkpoints_total_limit_removes_multiple_
155152
--resolution=64
156153
--train_batch_size=1
157154
--gradient_accumulation_steps=1
158-
--max_train_steps=11
155+
--max_train_steps=8
159156
--checkpointing_steps=2
160-
--resume_from_checkpoint=checkpoint-8
161-
--checkpoints_total_limit=3
157+
--resume_from_checkpoint=checkpoint-4
158+
--checkpoints_total_limit=2
162159
""".split()
163160

164161
run_command(self._launch_args + resume_run_args)
165162

166-
self.assertEqual(
167-
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
168-
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
169-
)
163+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-6", "checkpoint-8"})
170164

171165
def test_dreambooth_lora_if_model(self):
172166
with tempfile.TemporaryDirectory() as tmpdir:
@@ -328,7 +322,7 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
328322
--resolution 64
329323
--train_batch_size 1
330324
--gradient_accumulation_steps 1
331-
--max_train_steps 7
325+
--max_train_steps 6
332326
--checkpointing_steps=2
333327
--checkpoints_total_limit=2
334328
--learning_rate 5.0e-04
@@ -342,14 +336,11 @@ def test_dreambooth_lora_sdxl_checkpointing_checkpoints_total_limit(self):
342336

343337
pipe = DiffusionPipeline.from_pretrained(pipeline_path)
344338
pipe.load_lora_weights(tmpdir)
345-
pipe("a prompt", num_inference_steps=2)
339+
pipe("a prompt", num_inference_steps=1)
346340

347341
# check checkpoint directories exist
348-
self.assertEqual(
349-
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
350-
# checkpoint-2 should have been deleted
351-
{"checkpoint-4", "checkpoint-6"},
352-
)
342+
# checkpoint-2 should have been deleted
343+
self.assertEqual({x for x in os.listdir(tmpdir) if "checkpoint" in x}, {"checkpoint-4", "checkpoint-6"})
353344

354345
def test_dreambooth_lora_sdxl_text_encoder_checkpointing_checkpoints_total_limit(self):
355346
pipeline_path = "hf-internal-testing/tiny-stable-diffusion-xl-pipe"

examples/instruct_pix2pix/test_instruct_pix2pix.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit(self):
4040
--resolution=64
4141
--random_flip
4242
--train_batch_size=1
43-
--max_train_steps=7
43+
--max_train_steps=6
4444
--checkpointing_steps=2
4545
--checkpoints_total_limit=2
4646
--output_dir {tmpdir}
@@ -63,7 +63,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
6363
--resolution=64
6464
--random_flip
6565
--train_batch_size=1
66-
--max_train_steps=9
66+
--max_train_steps=4
6767
--checkpointing_steps=2
6868
--output_dir {tmpdir}
6969
--seed=0
@@ -74,7 +74,7 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
7474
# check checkpoint directories exist
7575
self.assertEqual(
7676
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
77-
{"checkpoint-2", "checkpoint-4", "checkpoint-6", "checkpoint-8"},
77+
{"checkpoint-2", "checkpoint-4"},
7878
)
7979

8080
resume_run_args = f"""
@@ -84,18 +84,18 @@ def test_instruct_pix2pix_checkpointing_checkpoints_total_limit_removes_multiple
8484
--resolution=64
8585
--random_flip
8686
--train_batch_size=1
87-
--max_train_steps=11
87+
--max_train_steps=8
8888
--checkpointing_steps=2
8989
--output_dir {tmpdir}
9090
--seed=0
91-
--resume_from_checkpoint=checkpoint-8
92-
--checkpoints_total_limit=3
91+
--resume_from_checkpoint=checkpoint-4
92+
--checkpoints_total_limit=2
9393
""".split()
9494

9595
run_command(self._launch_args + resume_run_args)
9696

9797
# check checkpoint directories exist
9898
self.assertEqual(
9999
{x for x in os.listdir(tmpdir) if "checkpoint" in x},
100-
{"checkpoint-6", "checkpoint-8", "checkpoint-10"},
100+
{"checkpoint-6", "checkpoint-8"},
101101
)

0 commit comments

Comments
 (0)