@@ -102,7 +102,38 @@ def test_dreambooth_lora_text_encoder_sd3(self):
102102                (key .startswith ("transformer" ) or  key .startswith ("text_encoder" )) for  key  in  lora_state_dict .keys ()
103103            )
104104            self .assertTrue (starts_with_expected_prefix )
105+     def  test_dreambooth_lora_latent_caching (self ):
106+         with  tempfile .TemporaryDirectory () as  tmpdir :
107+             test_args  =  f""" 
108+                 { self .script_path }  
109+                 --pretrained_model_name_or_path { self .pretrained_model_name_or_path }  
110+                 --instance_data_dir { self .instance_data_dir }  
111+                 --instance_prompt { self .instance_prompt }  
112+                 --resolution 64 
113+                 --train_batch_size 1 
114+                 --gradient_accumulation_steps 1 
115+                 --max_train_steps 2 
116+                 --cache_latents 
117+                 --learning_rate 5.0e-04 
118+                 --scale_lr 
119+                 --lr_scheduler constant 
120+                 --lr_warmup_steps 0 
121+                 --output_dir { tmpdir }  
122+                 """ .split ()
123+ 
124+             run_command (self ._launch_args  +  test_args )
125+             # save_pretrained smoke test 
126+             self .assertTrue (os .path .isfile (os .path .join (tmpdir , "pytorch_lora_weights.safetensors" )))
127+ 
128+             # make sure the state_dict has the correct naming in the parameters. 
129+             lora_state_dict  =  safetensors .torch .load_file (os .path .join (tmpdir , "pytorch_lora_weights.safetensors" ))
130+             is_lora  =  all ("lora"  in  k  for  k  in  lora_state_dict .keys ())
131+             self .assertTrue (is_lora )
105132
133+             # when not training the text encoder, all the parameters in the state dict should start 
134+             # with `"transformer"` in their names. 
135+             starts_with_transformer  =  all (key .startswith ("transformer" ) for  key  in  lora_state_dict .keys ())
136+             self .assertTrue (starts_with_transformer )
106137    def  test_dreambooth_lora_sd3_checkpointing_checkpoints_total_limit (self ):
107138        with  tempfile .TemporaryDirectory () as  tmpdir :
108139            test_args  =  f""" 
0 commit comments