@@ -127,6 +127,35 @@ class PeftLoraLoaderMixinTests:
127127 text_encoder_target_modules = ["q_proj" , "k_proj" , "v_proj" , "out_proj" ]
128128 denoiser_target_modules = ["to_q" , "to_k" , "to_v" , "to_out.0" ]
129129
130+ def test_simple_inference_save_pretrained_with_text_lora (self ):
131+ """
132+ Tests a simple usecase where users could use saving utilities for text encoder (only)
133+ LoRA through save_pretrained.
134+ """
135+ if not any ("text_encoder" in k for k in self .pipeline_class ._lora_loadable_modules ):
136+ pytest .skip ("Test not supported." )
137+ for scheduler_cls in self .scheduler_classes :
138+ pipe , inputs , _ , text_lora_config , denoiser_lora_config = self ._setup_pipeline_and_get_base_output (
139+ scheduler_cls
140+ )
141+ pipe , _ = self .add_adapters_to_pipeline (pipe , text_lora_config , denoiser_lora_config = None )
142+ images_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
143+
144+ with tempfile .TemporaryDirectory () as tmpdirname :
145+ pipe .save_pretrained (tmpdirname )
146+ pipe_from_pretrained = self .pipeline_class .from_pretrained (tmpdirname )
147+ pipe_from_pretrained .to (torch_device )
148+ modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = False )
149+
150+ for module_name , module in modules_to_save .items ():
151+ self .assertTrue (check_if_lora_correctly_set (module ), f"Lora not correctly set in { module_name } " )
152+
153+ images_lora_save_pretrained = pipe_from_pretrained (** inputs , generator = torch .manual_seed (0 ))[0 ]
154+ self .assertTrue (
155+ np .allclose (images_lora , images_lora_save_pretrained , atol = 1e-3 , rtol = 1e-3 ),
156+ "Loading from saved checkpoints should give same results." ,
157+ )
158+
130159 def test_low_cpu_mem_usage_with_injection (self ):
131160 """Tests if we can inject LoRA state dict with low_cpu_mem_usage."""
132161 for scheduler_cls in self .scheduler_classes :
0 commit comments