From b49e0ca2d84257e9b4c5e0ace98abc04f62f0e56 Mon Sep 17 00:00:00 2001
From: Thomas-MMJ <112830596+Thomas-MMJ@users.noreply.github.com>
Date: Mon, 12 Dec 2022 20:39:24 -0900
Subject: [PATCH] save memory in extract_lora_ups_down by returning a generator instead of a list

extract_lora_ups_down currently returns a list; by returning a generator we
can save some memory.
---
 lora_diffusion/lora.py | 12 +++++-------
 1 file changed, 5 insertions(+), 7 deletions(-)

diff --git a/lora_diffusion/lora.py b/lora_diffusion/lora.py
index b81ba16..c63312f 100644
--- a/lora_diffusion/lora.py
+++ b/lora_diffusion/lora.py
@@ -69,18 +69,16 @@ def inject_trainable_lora(
 
 
 def extract_lora_ups_down(model, target_replace_module=["CrossAttention", "Attention"]):
-
-    loras = []
-
+    no_injection = True
+
     for _module in model.modules():
         if _module.__class__.__name__ in target_replace_module:
             for _child_module in _module.modules():
                 if _child_module.__class__.__name__ == "LoraInjectedLinear":
-                    loras.append((_child_module.lora_up, _child_module.lora_down))
-    if len(loras) == 0:
+                    no_injection = False
+                    yield (_child_module.lora_up, _child_module.lora_down)
+    if no_injection:
         raise ValueError("No lora injected.")
-    return loras
-
 
 def save_lora_weight(
     model, path="./lora.pt", target_replace_module=["CrossAttention", "Attention"]
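
For reference, a minimal usage sketch of the generator-based extract_lora_ups_down. It assumes the patched lora_diffusion.lora is importable and that `unet` is a model that has already been through inject_trainable_lora; the variable names are illustrative, not taken from the patch. Because generators are lazy, the "No lora injected." ValueError is now raised when the generator is first iterated rather than at the call site.

    from lora_diffusion.lora import extract_lora_ups_down

    # `unet` is an assumed, already LoRA-injected model.
    # Pairs of (lora_up, lora_down) modules are yielded one at a time
    # instead of being accumulated into a list up front.
    for lora_up, lora_down in extract_lora_ups_down(unet):
        print(lora_up.weight.shape, lora_down.weight.shape)

    # Callers that need the old list-returning behaviour can materialize it:
    loras = list(extract_lora_ups_down(unet))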