Skip to content

Commit 5c32829

Browse files
committed
add tests for latent caching
1 parent ecccd75 commit 5c32829

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

examples/dreambooth/test_dreambooth_lora_sd3.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,38 @@ def test_dreambooth_lora_text_encoder_sd3(self):
102102
(key.startswith("transformer") or key.startswith("text_encoder")) for key in lora_state_dict.keys()
103103
)
104104
self.assertTrue(starts_with_expected_prefix)
105+
def test_dreambooth_lora_latent_caching(self):
106+
with tempfile.TemporaryDirectory() as tmpdir:
107+
test_args = f"""
108+
{self.script_path}
109+
--pretrained_model_name_or_path {self.pretrained_model_name_or_path}
110+
--instance_data_dir {self.instance_data_dir}
111+
--instance_prompt {self.instance_prompt}
112+
--resolution 64
113+
--train_batch_size 1
114+
--gradient_accumulation_steps 1
115+
--max_train_steps 2
116+
--cache_latents
117+
--learning_rate 5.0e-04
118+
--scale_lr
119+
--lr_scheduler constant
120+
--lr_warmup_steps 0
121+
--output_dir {tmpdir}
122+
""".split()
123+
124+
run_command(self._launch_args + test_args)
125+
# save_pretrained smoke test
126+
self.assertTrue(os.path.isfile(os.path.join(tmpdir, "pytorch_lora_weights.safetensors")))
127+
128+
# make sure the state_dict has the correct naming in the parameters.
129+
lora_state_dict = safetensors.torch.load_file(os.path.join(tmpdir, "pytorch_lora_weights.safetensors"))
130+
is_lora = all("lora" in k for k in lora_state_dict.keys())
131+
self.assertTrue(is_lora)
105132

133+
# when not training the text encoder, all the parameters in the state dict should start
134+
# with `"transformer"` in their names.
135+
starts_with_transformer = all(key.startswith("transformer") for key in lora_state_dict.keys())
136+
self.assertTrue(starts_with_transformer)
106137
def test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit(self):
107138
with tempfile.TemporaryDirectory() as tmpdir:
108139
test_args = f"""

0 commit comments

Comments
 (0)