@@ -2510,3 +2510,34 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
25102510 # materializes the test methods on invocation which cannot be overridden.
25112511 return
25122512 self ._test_group_offloading_inference_denoiser (offload_type , use_stream )
2513+
2514+ @require_torch_accelerator
2515+ def test_lora_loading_model_cpu_offload (self ):
2516+ components , _ , denoiser_lora_config = self .get_dummy_components (self .scheduler_classes [0 ])
2517+ _ , _ , inputs = self .get_dummy_inputs (with_generator = False )
2518+ pipe = self .pipeline_class (** components )
2519+ pipe = pipe .to (torch_device )
2520+ pipe .set_progress_bar_config (disable = None )
2521+
2522+ denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
2523+ denoiser .add_adapter (denoiser_lora_config )
2524+ self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
2525+
2526+ output_lora = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2527+
2528+ with tempfile .TemporaryDirectory () as tmpdirname :
2529+ modules_to_save = self ._get_modules_to_save (pipe , has_denoiser = True )
2530+ lora_state_dicts = self ._get_lora_state_dicts (modules_to_save )
2531+ self .pipeline_class .save_lora_weights (
2532+ save_directory = tmpdirname , safe_serialization = True , ** lora_state_dicts
2533+ )
2534+ # reinitialize the pipeline to mimic the inference workflow.
2535+ components , _ , denoiser_lora_config = self .get_dummy_components (self .scheduler_classes [0 ])
2536+ pipe = self .pipeline_class (** components )
2537+ pipe .enable_model_cpu_offload (device = torch_device )
2538+ pipe .load_lora_weights (tmpdirname )
2539+ denoiser = pipe .transformer if self .unet_kwargs is None else pipe .unet
2540+ self .assertTrue (check_if_lora_correctly_set (denoiser ), "Lora not correctly set in denoiser." )
2541+
2542+ output_lora_loaded = pipe (** inputs , generator = torch .manual_seed (0 ))[0 ]
2543+ self .assertTrue (np .allclose (output_lora , output_lora_loaded , atol = 1e-3 , rtol = 1e-3 ))
0 commit comments