
Commit fc88ac4: move private methods to the bottom

Parent: 27fe7c5


tests/lora/utils.py

Lines changed: 55 additions & 55 deletions
@@ -251,61 +251,6 @@ def get_dummy_tokens(self):
         prepared_inputs["input_ids"] = inputs
         return prepared_inputs
 
-    def _get_lora_state_dicts(self, modules_to_save):
-        state_dicts = {}
-        for module_name, module in modules_to_save.items():
-            if module is not None:
-                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
-        return state_dicts
-
-    def _get_lora_adapter_metadata(self, modules_to_save):
-        metadatas = {}
-        for module_name, module in modules_to_save.items():
-            if module is not None:
-                metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
-        return metadatas
-
-    def _get_modules_to_save(self, pipe, has_denoiser=False):
-        modules_to_save = {}
-        lora_loadable_modules = self.pipeline_class._lora_loadable_modules
-
-        if (
-            "text_encoder" in lora_loadable_modules
-            and hasattr(pipe, "text_encoder")
-            and getattr(pipe.text_encoder, "peft_config", None) is not None
-        ):
-            modules_to_save["text_encoder"] = pipe.text_encoder
-
-        if (
-            "text_encoder_2" in lora_loadable_modules
-            and hasattr(pipe, "text_encoder_2")
-            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
-        ):
-            modules_to_save["text_encoder_2"] = pipe.text_encoder_2
-
-        if has_denoiser:
-            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
-                modules_to_save["unet"] = pipe.unet
-
-            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
-                modules_to_save["transformer"] = pipe.transformer
-
-        return modules_to_save
-
-    def _get_exclude_modules(self, pipe):
-        from diffusers.utils.peft_utils import _derive_exclude_modules
-
-        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
-        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
-        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
-        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
-        pipe.unload_lora_weights()
-        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
-        exclude_modules = _derive_exclude_modules(
-            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
-        )
-        return exclude_modules
-
     def add_adapters_to_pipeline(self, pipe, text_lora_config=None, denoiser_lora_config=None, adapter_name="default"):
         if text_lora_config is not None:
             if "text_encoder" in self.pipeline_class._lora_loadable_modules:
@@ -2408,3 +2353,58 @@ def test_group_offloading_inference_denoiser(self, offload_type, use_stream):
                 # materializes the test methods on invocation which cannot be overridden.
                 return
         self._test_group_offloading_inference_denoiser(offload_type, use_stream)
+
+    def _get_lora_state_dicts(self, modules_to_save):
+        state_dicts = {}
+        for module_name, module in modules_to_save.items():
+            if module is not None:
+                state_dicts[f"{module_name}_lora_layers"] = get_peft_model_state_dict(module)
+        return state_dicts
+
+    def _get_lora_adapter_metadata(self, modules_to_save):
+        metadatas = {}
+        for module_name, module in modules_to_save.items():
+            if module is not None:
+                metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
+        return metadatas
+
+    def _get_modules_to_save(self, pipe, has_denoiser=False):
+        modules_to_save = {}
+        lora_loadable_modules = self.pipeline_class._lora_loadable_modules
+
+        if (
+            "text_encoder" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder")
+            and getattr(pipe.text_encoder, "peft_config", None) is not None
+        ):
+            modules_to_save["text_encoder"] = pipe.text_encoder
+
+        if (
+            "text_encoder_2" in lora_loadable_modules
+            and hasattr(pipe, "text_encoder_2")
+            and getattr(pipe.text_encoder_2, "peft_config", None) is not None
+        ):
+            modules_to_save["text_encoder_2"] = pipe.text_encoder_2
+
+        if has_denoiser:
+            if "unet" in lora_loadable_modules and hasattr(pipe, "unet"):
+                modules_to_save["unet"] = pipe.unet
+
+            if "transformer" in lora_loadable_modules and hasattr(pipe, "transformer"):
+                modules_to_save["transformer"] = pipe.transformer
+
+        return modules_to_save
+
+    def _get_exclude_modules(self, pipe):
+        from diffusers.utils.peft_utils import _derive_exclude_modules
+
+        modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)
+        denoiser = "unet" if self.unet_kwargs is not None else "transformer"
+        modules_to_save = {k: v for k, v in modules_to_save.items() if k == denoiser}
+        denoiser_lora_state_dict = self._get_lora_state_dicts(modules_to_save)[f"{denoiser}_lora_layers"]
+        pipe.unload_lora_weights()
+        denoiser_state_dict = pipe.unet.state_dict() if self.unet_kwargs is not None else pipe.transformer.state_dict()
+        exclude_modules = _derive_exclude_modules(
+            denoiser_state_dict, denoiser_lora_state_dict, adapter_name="default"
+        )
+        return exclude_modules
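
For reference, the relocated helpers are designed to compose: collect the LoRA-bearing modules first, then derive state dicts, adapter metadata, or exclusion lists from that mapping. A minimal sketch of that flow inside a test method (assuming `pipe` and the test-class attributes shown in the diff above; this snippet is illustrative, not an excerpt from the file):

    # Collect every loadable module in the pipeline that carries a PEFT config.
    modules_to_save = self._get_modules_to_save(pipe, has_denoiser=True)

    # Per-module LoRA weights, keyed like "transformer_lora_layers".
    state_dicts = self._get_lora_state_dicts(modules_to_save)

    # Per-module adapter configs, keyed like "transformer_lora_adapter_metadata".
    metadata = self._get_lora_adapter_metadata(modules_to_save)

    # Module names to skip when re-deriving LoRA targets; note this calls
    # pipe.unload_lora_weights() as a side effect.
    exclude_modules = self._get_exclude_modules(pipe)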
