
Commit ab0bfa7

Handle loras in modular denoise
1 parent 7c975f0 commit ab0bfa7

4 files changed: +227 -4 lines changed

invokeai/app/invocations/denoise_latents.py

Lines changed: 11 additions & 0 deletions
@@ -60,6 +60,7 @@
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
+from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -833,6 +834,16 @@ def step_callback(state: PipelineIntermediateState) -> None:
         if self.unet.freeu_config:
             ext_manager.add_extension(FreeUExt(self.unet.freeu_config))
 
+        ### lora
+        if self.unet.loras:
+            ext_manager.add_extension(
+                LoRAPatcherExt(
+                    node_context=context,
+                    loras=self.unet.loras,
+                    prefix="lora_unet_",
+                )
+            )
+
         # context for loading additional models
         with ExitStack() as exit_stack:
             # later should be smth like:
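For orientation, here is a minimal, hypothetical sketch (not part of this commit; the function name and the zero-argument ExtensionsManager() call are assumptions) of the lifecycle the node wires up: the LoRA extension is registered on the manager alongside the other extensions, and the manager later enters every extension's patch_unet() context around the denoising loop, restoring the UNet afterwards (see extensions_manager.py below).

    from typing import List

    from diffusers import UNet2DConditionModel

    from invokeai.app.invocations.model import LoRAField
    from invokeai.app.services.shared.invocation_context import InvocationContext
    from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
    from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager


    def denoise_with_loras(context: InvocationContext, unet: UNet2DConditionModel, loras: List[LoRAField]) -> None:
        ext_manager = ExtensionsManager()
        if loras:
            # Only LoRA layer keys starting with "lora_unet_" are applied to the UNet.
            ext_manager.add_extension(LoRAPatcherExt(node_context=context, loras=loras, prefix="lora_unet_"))

        # The manager enters each extension's patch_unet() context; on exit the
        # original UNet parameters are restored.
        with ext_manager.patch_unet(unet, cached_weights=None):
            ...  # run the denoising loop with the LoRA deltas applied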

invokeai/backend/lora.py

Lines changed: 18 additions & 0 deletions
@@ -49,6 +49,9 @@ def __init__(
     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         raise NotImplementedError()
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        raise NotImplementedError()
+
     def calc_size(self) -> int:
         model_size = 0
         for val in [self.bias]:
@@ -93,6 +96,9 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
 
         return weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.up, self.mid, self.down]:
@@ -149,6 +155,9 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
 
         return weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1_a, self.w1_b, self.w2_a, self.w2_b, self.t1, self.t2]:
@@ -241,6 +250,9 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
 
         return weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         for val in [self.w1, self.w1_a, self.w1_b, self.w2, self.w2_a, self.w2_b, self.t2]:
@@ -293,6 +305,9 @@ def __init__(
     def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         return self.weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
@@ -327,6 +342,9 @@ def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
         assert orig_weight is not None
         return orig_weight * weight
 
+    def get_parameters(self, orig_module: Optional[torch.nn.Module]) -> Dict[str, torch.Tensor]:
+        return {"weight": self.get_weight(orig_module.weight)}
+
     def calc_size(self) -> int:
         model_size = super().calc_size()
         model_size += self.weight.nelement() * self.weight.element_size()
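The new get_parameters() hook generalizes get_weight(): instead of a single weight delta, a layer returns a dict of patch tensors keyed by the name of the module parameter each one applies to. A simplified, illustrative layer (not one of the library's classes) showing the pattern:

    from typing import Dict, Optional

    import torch


    class ToyLoRALayer:
        """Illustrative only: a classic rank-decomposition LoRA layer."""

        def __init__(self, up: torch.Tensor, down: torch.Tensor, alpha: Optional[float] = None):
            self.up = up        # (out_features, rank)
            self.down = down    # (rank, in_features)
            self.alpha = alpha
            self.rank = down.shape[0]

        def get_weight(self, orig_weight: Optional[torch.Tensor]) -> torch.Tensor:
            # The delta to add to the original weight: up @ down.
            return self.up @ self.down

        def get_parameters(self, orig_module: torch.nn.Module) -> Dict[str, torch.Tensor]:
            # One entry per module parameter this layer patches; here only "weight",
            # which matches the implementations added in this commit.
            return {"weight": self.get_weight(orig_module.weight)}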
invokeai/backend/stable_diffusion/extensions/lora_patcher.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Tuple
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+from invokeai.backend.util.devices import TorchDevice
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.model import LoRAField
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+    from invokeai.backend.lora import LoRAModelRaw
+
+
+class LoRAPatcherExt(ExtensionBase):
+    def __init__(
+        self,
+        node_context: InvocationContext,
+        loras: List[LoRAField],
+        prefix: str,
+    ):
+        super().__init__()
+        self._loras = loras
+        self._prefix = prefix
+        self._node_context = node_context
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
+            for lora in self._loras:
+                lora_info = self._node_context.models.load(lora.lora)
+                lora_model = lora_info.model
+                yield (lora_model, lora.weight)
+                del lora_info
+            return
+
+        yield self._patch_model(
+            model=unet,
+            prefix=self._prefix,
+            loras=_lora_loader(),
+            cached_weights=cached_weights,
+        )
+
+    @classmethod
+    @contextmanager
+    def static_patch_model(
+        cls,
+        model: torch.nn.Module,
+        prefix: str,
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        modified_cached_weights, modified_weights = cls._patch_model(
+            model=model,
+            prefix=prefix,
+            loras=loras,
+            cached_weights=cached_weights,
+        )
+        try:
+            yield
+
+        finally:
+            with torch.no_grad():
+                for param_key in modified_cached_weights:
+                    model.get_parameter(param_key).copy_(cached_weights[param_key])
+                for param_key, weight in modified_weights.items():
+                    model.get_parameter(param_key).copy_(weight)
+
+    @classmethod
+    def _patch_model(
+        cls,
+        model: UNet2DConditionModel,
+        prefix: str,
+        loras: Iterator[Tuple[LoRAModelRaw, float]],
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """
+        Apply one or more LoRAs to a model.
+        :param model: The model to patch.
+        :param loras: An iterator that returns the LoRA to patch in and its patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        """
+        if cached_weights is None:
+            cached_weights = {}
+
+        modified_weights = {}
+        modified_cached_weights = set()
+        with torch.no_grad():
+            for lora, lora_weight in loras:
+                # assert lora.device.type == "cpu"
+                for layer_key, layer in lora.layers.items():
+                    if not layer_key.startswith(prefix):
+                        continue
+
+                    # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                    # should be improved in the following ways:
+                    # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                    #    LoRA model is applied.
+                    # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                    #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                    #    weights to have valid keys.
+                    assert isinstance(model, torch.nn.Module)
+                    module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                    # All of the LoRA weight calculations will be done on the same device as the module weight.
+                    # (Performance will be best if this is a CUDA device.)
+                    device = module.weight.device
+                    dtype = module.weight.dtype
+
+                    layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+                    # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                    # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                    # same thing in a single call to '.to(...)'.
+                    layer.to(device=device)
+                    layer.to(dtype=torch.float32)
+
+                    # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                    # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                    for param_name, lora_param_weight in layer.get_parameters(module).items():
+                        param_key = module_key + "." + param_name
+                        module_param = module.get_parameter(param_name)
+
+                        # save original weight
+                        if param_key not in modified_cached_weights and param_key not in modified_weights:
+                            if param_key in cached_weights:
+                                modified_cached_weights.add(param_key)
+                            else:
+                                modified_weights[param_key] = module_param.detach().to(
+                                    device=TorchDevice.CPU_DEVICE, copy=True
+                                )
+
+                        if module_param.shape != lora_param_weight.shape:
+                            # TODO: debug on lycoris
+                            lora_param_weight = lora_param_weight.reshape(module_param.shape)
+
+                        lora_param_weight *= (lora_weight * layer_scale)
+                        module_param += lora_param_weight.to(dtype=dtype)
+
+                    layer.to(device=TorchDevice.CPU_DEVICE)
+
+        return modified_cached_weights, modified_weights
+
+    @staticmethod
+    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
+        assert "." not in lora_key
+
+        if not lora_key.startswith(prefix):
+            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
+
+        module = model
+        module_key = ""
+        key_parts = lora_key[len(prefix) :].split("_")
+
+        submodule_name = key_parts.pop(0)
+
+        while len(key_parts) > 0:
+            try:
+                module = module.get_submodule(submodule_name)
+                module_key += "." + submodule_name
+                submodule_name = key_parts.pop(0)
+            except Exception:
+                submodule_name += "_" + key_parts.pop(0)
+
+        module = module.get_submodule(submodule_name)
+        module_key = (module_key + "." + submodule_name).lstrip(".")
+
+        return (module_key, module)
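A hedged usage sketch of the class-level context manager added above (the function and variable names below are assumptions, not from this commit). Parameters are patched in place on entry; on exit the saved originals are copied back under torch.no_grad(). Note also the key convention handled by _resolve_lora_key(): a flat key such as lora_unet_down_blocks_0_attentions_0_proj_in is resolved, underscore group by underscore group, to the submodule path down_blocks.0.attentions.0.proj_in.

    from typing import List, Tuple

    import torch
    from diffusers import UNet2DConditionModel

    from invokeai.backend.lora import LoRAModelRaw
    from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt


    def run_inference_with_loras(unet: UNet2DConditionModel, loras: List[Tuple[LoRAModelRaw, float]]) -> None:
        # Optionally pass a read-only CPU copy of the state dict so originals are
        # restored from it instead of being snapshotted lazily during patching.
        cached_weights = None

        with LoRAPatcherExt.static_patch_model(
            model=unet,
            prefix="lora_unet_",
            loras=iter(loras),
            cached_weights=cached_weights,
        ):
            ...  # run inference; the LoRA deltas have been added to the UNet parameters
        # On exit, every modified parameter is copied back to its original value.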

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 26 additions & 4 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -67,9 +67,31 @@ def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[s
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        # TODO: create weight patch logic in PR with extension which uses it
-        with ExitStack() as exit_stack:
+        modified_weights: Dict[str, torch.Tensor] = {}
+        modified_cached_weights: Set[str] = set()
+
+        exit_stack = ExitStack()
+        try:
             for ext in self._extensions:
-                exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+                res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
+                if res is None:
+                    continue
+                ext_modified_cached_weights, ext_modified_weights = res
+
+                modified_cached_weights.update(ext_modified_cached_weights)
+                # store only first returned weight for each key, because
+                # next extension which changes it, will work with already modified weight
+                for param_key, weight in ext_modified_weights.items():
+                    if param_key in modified_weights:
+                        continue
+                    modified_weights[param_key] = weight
 
             yield None
+
+        finally:
+            exit_stack.close()
+            with torch.no_grad():
+                for param_key in modified_cached_weights:
+                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
+                for param_key, weight in modified_weights.items():
+                    unet.get_parameter(param_key).copy_(weight)
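A small self-contained illustration (hypothetical tensors, not library code) of why only the first saved weight per key is kept: when two extensions patch the same parameter, the second extension snapshots a value that is already modified, so restoring the first snapshot is what returns the parameter to its true original state.

    import torch

    original = torch.zeros(4)
    param = original.clone()

    modified_weights = {}

    # Extension A patches the parameter; the value it saw (and saved) is the original.
    modified_weights.setdefault("unet.down.weight", param.detach().clone())
    param += 1.0

    # Extension B patches the same parameter; it saw an already-modified value, so the
    # manager deliberately keeps the first snapshot instead of overwriting it.
    modified_weights.setdefault("unet.down.weight", param.detach().clone())
    param += 2.0

    # Teardown: restoring the first snapshot recovers the original tensor.
    param.copy_(modified_weights["unet.down.weight"])
    assert torch.equal(param, original)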
