 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import copy
 import inspect
 import os
 import re
@@ -292,20 +291,6 @@ def _get_modules_to_save(self, pipe, has_denoiser=False):
 
         return modules_to_save
 
-    def _get_exclude_modules(self, pipe):
-        from diffusers.utils.peft_utils import _derive_exclude_modules
-
-        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
-        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
-        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
-        pipe.unload_lora_weights()
-        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
-        exclude_modules = _derive_exclude_modules(
-            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
-        )
-        return exclude_modules
-
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
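
Note on the helper removed above: `_get_exclude_modules` leaned on `diffusers.utils.peft_utils._derive_exclude_modules`, which compares the denoiser's full state dict against its LoRA state dict and returns the module names that carry no LoRA parameters. Below is a minimal sketch of that derivation idea only; `derive_exclude_modules_sketch` and the toy keys are illustrative, not diffusers' actual implementation.

```python
# Sketch: modules present in the base state dict but absent from the LoRA
# state dict become candidates for `exclude_modules`.
def derive_exclude_modules_sketch(model_sd: dict, lora_sd: dict) -> list:
    # Module names that carry LoRA parameters, e.g. "blocks.0.attn.to_q".
    lora_modules = {k.split(".lora_")[0] for k in lora_sd if ".lora_" in k}
    # All module names in the base model (strip the trailing parameter name).
    all_modules = {k.rsplit(".", 1)[0] for k in model_sd}
    return sorted(all_modules - lora_modules)


# Toy usage with made-up keys:
model_sd = {"blocks.0.attn.to_q.weight": 0, "proj_out.weight": 0}
lora_sd = {"blocks.0.attn.to_q.lora_A.weight": 0}
print(derive_exclude_modules_sketch(model_sd, lora_sd))  # ['proj_out']
```
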
@@ -2342,58 +2327,6 @@ def test_lora_unload_add_adapter(self):
         )
         _ = pipe(**inputs, generator=torch.manual_seed(0))[0]
 
-    @require_peft_version_greater("0.13.2")
-    def test_lora_exclude_modules(self):
-        """
-        Checks that `exclude_modules` works as expected. It works in the following way:
-        we first create a pipeline and insert a LoRA config into it. We then derive a `set`
-        of modules to exclude by comparing the denoiser state dict against the denoiser
-        LoRA state dict.
-
-        We then create a new LoRA config that includes the `exclude_modules` and run the checks.
-        """
-        scheduler_cls = self.scheduler_classes[0]
-        components, text_lora_config, denoiser_lora_config = self.get_dummy_components(scheduler_cls)
-        pipe = self.pipeline_class(**components).to(torch_device)
-        _, _, inputs = self.get_dummy_inputs(with_generator=False)
-
-        output_no_lora = pipe(**inputs, generator=torch.manual_seed(0))[0]
-        self.assertTrue(output_no_lora.shape == self.output_shape)
-
-        # only supported for the denoiser for now
-        pipe_cp = copy.deepcopy(pipe)
-        pipe_cp, _ = self.add_adapters_to_pipeline(
-            pipe_cp, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        denoiser_exclude_modules = self._get_exclude_modules(pipe_cp)
-        pipe_cp.to("cpu")
-        del pipe_cp
-
-        denoiser_lora_config.exclude_modules = denoiser_exclude_modules
-        pipe, _ = self.add_adapters_to_pipeline(
-            pipe, text_lora_config=text_lora_config, denoiser_lora_config=denoiser_lora_config
-        )
-        output_lora_exclude_modules = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-        with tempfile.TemporaryDirectory() as tmpdir:
-            modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-            lora_state_dicts = self._get_lora_state_dicts(modules_to_save)
-            lora_metadatas = self._get_lora_adapter_metadata(modules_to_save)
-            self.pipeline_class.save_lora_weights(save_directory=tmpdir, **lora_state_dicts, **lora_metadatas)
-            pipe.unload_lora_weights()
-            pipe.load_lora_weights(tmpdir)
-
-            output_lora_pretrained = pipe(**inputs, generator=torch.manual_seed(0))[0]
-
-            self.assertTrue(
-                not np.allclose(output_no_lora, output_lora_exclude_modules, atol=1e-3, rtol=1e-3),
-                "LoRA should change outputs.",
-            )
-            self.assertTrue(
-                np.allclose(output_lora_exclude_modules, output_lora_pretrained, atol=1e-3, rtol=1e-3),
-                "LoRA outputs should match.",
-            )
-
     def test_inference_load_delete_load_adapters(self):
         "Tests if `load_lora_weights()` -> `delete_adapters()` -> `load_lora_weights()` works."
         for scheduler_cls in self.scheduler_classes:
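
For context on the test removed above: it assigned the derived names to `denoiser_lora_config.exclude_modules` before re-adding the adapters. In PEFT itself, `exclude_modules` is a `LoraConfig` field in recent releases, which is why the test was gated on `require_peft_version_greater("0.13.2")`. A hedged standalone sketch follows; the target and exclude names are made up.

```python
from peft import LoraConfig

# Modules matching `exclude_modules` are skipped even if `target_modules`
# would otherwise select them (behavior as documented by PEFT).
config = LoraConfig(
    r=4,
    lora_alpha=4,
    target_modules=["to_q", "to_k", "to_v"],  # illustrative target names
    exclude_modules=["proj_out"],             # illustrative exclusion
)
```
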
@@ -2467,7 +2400,6 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
 
         components, _, _ = self.get_dummy_components(self.scheduler_classes[0])
         pipe = self.pipeline_class(**components)
-        pipe = pipe.to(torch_device)
         pipe.set_progress_bar_config(disable=None)
         denoiser = pipe.transformer if self.unet_kwargs is None else pipe.unet
 
@@ -2483,6 +2415,10 @@ def _test_group_offloading_inference_denoiser(self, offload_type, use_stream):
             num_blocks_per_group=1,
             use_stream=use_stream,
         )
+        # Place other model-level components on `torch_device`.
+        for _, component in pipe.components.items():
+            if isinstance(component, torch.nn.Module):
+                component.to(torch_device)
         group_offload_hook_1 = _get_top_level_group_offload_hook(denoiser)
         self.assertTrue(group_offload_hook_1 is not None)
         output_1 = pipe(**inputs, generator=torch.manual_seed(0))[0]
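
This hunk also explains why `pipe = pipe.to(torch_device)` was dropped earlier: group offloading manages the denoiser's device placement itself, so only the remaining `nn.Module` components need an explicit move. A minimal sketch of the pattern outside the test harness; the checkpoint id is a placeholder and the `enable_group_offload` arguments may vary across diffusers versions.

```python
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("org/some-checkpoint")  # placeholder id

# Group-offload only the denoiser: blocks shuttle between CPU and GPU on demand,
# so calling `.to("cuda")` on the whole pipeline would defeat the purpose.
pipe.transformer.enable_group_offload(
    onload_device=torch.device("cuda"),
    offload_device=torch.device("cpu"),
    offload_type="block_level",
    num_blocks_per_group=1,
)

# Move everything else explicitly, mirroring the loop added in the diff above.
for component in pipe.components.values():
    if isinstance(component, torch.nn.Module) and component is not pipe.transformer:
        component.to("cuda")
```
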