Commit faa88f7

Make lora as separate extensions

1 parent 46c632e commit faa88f7

5 files changed: +192 -247 lines changed

invokeai/app/invocations/compel.py

Lines changed: 4 additions & 4 deletions
@@ -80,12 +80,12 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora_text_encoder(
                 text_encoder,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
@@ -175,13 +175,13 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:

         with (
             # apply all patches while the model is on the target device
-            text_encoder_info.model_on_device() as (state_dict, text_encoder),
+            text_encoder_info.model_on_device() as (cached_weights, text_encoder),
             tokenizer_info as tokenizer,
             ModelPatcher.apply_lora(
                 text_encoder,
                 loras=_lora_loader(),
                 prefix=lora_prefix,
-                model_state_dict=state_dict,
+                cached_weights=cached_weights,
             ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
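For readers unfamiliar with the renamed argument, the snippet below is a self-contained toy (plain PyTorch, not the InvokeAI API; the helper `patch_then_restore` is invented for illustration) of the contract behind `cached_weights`: when a CPU copy of the state dict is already available, as with the value yielded by `model_on_device()` above, the patcher can restore original weights from that copy instead of first copying them off the model before patching.

import torch
from typing import Dict, Optional


def patch_then_restore(module: torch.nn.Linear, delta: torch.Tensor,
                       cached_weights: Optional[Dict[str, torch.Tensor]] = None) -> None:
    saved: Dict[str, torch.Tensor] = {}
    with torch.no_grad():
        if cached_weights is None:
            # No CPU cache was provided, so the original weight must be copied out first.
            saved["weight"] = module.weight.detach().to("cpu", copy=True)
        module.weight += delta  # stand-in for adding a scaled LoRA delta
        # ... run inference with the patched module here ...
        original = cached_weights["weight"] if cached_weights is not None else saved["weight"]
        module.weight.copy_(original)


m = torch.nn.Linear(3, 3)
cache = {k: v.detach().to("cpu", copy=True) for k, v in m.state_dict().items()}
patch_then_restore(m, torch.ones(3, 3), cached_weights=cache)
assert torch.equal(m.weight, cache["weight"])  # weights restored from the CPU cache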

invokeai/app/invocations/denoise_latents.py

Lines changed: 10 additions & 9 deletions
@@ -60,7 +60,7 @@
 from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
 from invokeai.backend.stable_diffusion.extensions.controlnet import ControlNetExt
 from invokeai.backend.stable_diffusion.extensions.freeu import FreeUExt
-from invokeai.backend.stable_diffusion.extensions.lora_patcher import LoRAPatcherExt
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.stable_diffusion.extensions.preview import PreviewExt
 from invokeai.backend.stable_diffusion.extensions.rescale_cfg import RescaleCFGExt
 from invokeai.backend.stable_diffusion.extensions_manager import ExtensionsManager
@@ -836,13 +836,14 @@ def step_callback(state: PipelineIntermediateState) -> None:

         ### lora
         if self.unet.loras:
-            ext_manager.add_extension(
-                LoRAPatcherExt(
-                    node_context=context,
-                    loras=self.unet.loras,
-                    prefix="lora_unet_",
+            for lora_field in self.unet.loras:
+                ext_manager.add_extension(
+                    LoRAExt(
+                        node_context=context,
+                        model_id=lora_field.lora,
+                        weight=lora_field.weight,
+                    )
                 )
-            )

         # context for loading additional models
         with ExitStack() as exit_stack:
@@ -924,14 +925,14 @@ def _lora_loader() -> Iterator[Tuple[LoRAModelRaw, float]]:
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info.model_on_device() as (model_state_dict, unet),
+            unet_info.model_on_device() as (cached_weights, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             set_seamless(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
             ModelPatcher.apply_lora_unet(
                 unet,
                 loras=_lora_loader(),
-                model_state_dict=model_state_dict,
+                cached_weights=cached_weights,
             ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
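The registration change above is the heart of the commit: instead of one patcher extension that owns the whole LoRA list, each LoRA field becomes its own extension carrying only a model identifier and a weight. The toy below (not the ExtensionsManager API; `ToyLoRAPatch` is invented for illustration) sketches the composition pattern this enables: N independent patchers stacked with an `ExitStack`, each one restoring only what it changed when the stack unwinds.

from contextlib import ExitStack, contextmanager
import torch


class ToyLoRAPatch:
    def __init__(self, delta: torch.Tensor, weight: float):
        self._delta = delta
        self._weight = weight

    @contextmanager
    def patch(self, module: torch.nn.Linear):
        # Each patcher remembers the weight as it was when it entered ...
        original = module.weight.detach().clone()
        with torch.no_grad():
            module.weight += self._delta * self._weight
        try:
            yield
        finally:
            # ... and restores exactly that value on exit (LIFO order via ExitStack).
            with torch.no_grad():
                module.weight.copy_(original)


linear = torch.nn.Linear(2, 2)
patches = [ToyLoRAPatch(torch.ones(2, 2), 0.5), ToyLoRAPatch(torch.ones(2, 2), 0.25)]

before = linear.weight.detach().clone()
with ExitStack() as stack:
    for p in patches:  # analogous to registering one LoRAExt per LoRA field
        stack.enter_context(p.patch(linear))
    ...  # denoise with the patched model here
assert torch.allclose(linear.weight, before)  # weights restored after the stack exits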

invokeai/backend/model_patcher.py

Lines changed: 33 additions & 62 deletions
@@ -5,7 +5,7 @@

 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union

 import numpy as np
 import torch
@@ -17,8 +17,8 @@
 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.devices import TorchDevice

 """
 loras = [
@@ -85,13 +85,13 @@ def apply_lora_unet(
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
             prefix="lora_unet_",
-            model_state_dict=model_state_dict,
+            cached_weights=cached_weights,
         ):
             yield

@@ -101,9 +101,9 @@ def apply_lora_text_encoder(
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
             yield

     @classmethod
@@ -113,74 +113,45 @@ def apply_lora(
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.

         :param model: The model to patch.
         :param loras: An iterator that returns the LoRA to patch in and its patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        original_weights = {}
+        modified_cached_weights: Set[str] = set()
+        modified_weights: Dict[str, torch.Tensor] = {}
         try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
-
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
-                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=TorchDevice.CPU_DEVICE)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
-
-            yield  # wait for context manager exit
+            for lora_model, lora_weight in loras:
+                lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
+                    model=model,
+                    prefix=prefix,
+                    lora=lora_model,
+                    lora_weight=lora_weight,
+                    cached_weights=cached_weights,
+                )
+                del lora_model
+
+                modified_cached_weights.update(lora_modified_cached_weights)
+                # Store only the first returned weight for each key: any later patch that
+                # changes the same key is already working with the modified weight.
+                for param_key, weight in lora_modified_weights.items():
+                    if param_key in modified_weights:
+                        continue
+                    modified_weights[param_key] = weight
+
+            yield

         finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                for param_key in modified_cached_weights:
+                    model.get_parameter(param_key).copy_(cached_weights[param_key])
+                for param_key, weight in modified_weights.items():
+                    model.get_parameter(param_key).copy_(weight)

     @classmethod
     @contextmanager
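The rewritten apply_lora() delegates the actual patching to LoRAExt.patch_model() and keeps only the restore bookkeeping: parameter keys whose originals already exist in the CPU `cached_weights` dict are remembered by name, while any other touched parameter is copied to CPU before modification, and both groups are restored in the `finally` block. Below is a self-contained toy of that scheme (plain PyTorch, not InvokeAI code), reduced to a single parameter.

import torch
from typing import Dict, Set

model = torch.nn.Linear(4, 4)
cached_weights = {k: v.detach().clone() for k, v in model.state_dict().items()}

modified_cached_weights: Set[str] = set()      # originals already available in cached_weights
modified_weights: Dict[str, torch.Tensor] = {}  # originals we had to copy ourselves

try:
    with torch.no_grad():
        param_key = "weight"
        if param_key in cached_weights:
            modified_cached_weights.add(param_key)  # no copy needed, just remember the key
        else:
            modified_weights[param_key] = model.get_parameter(param_key).detach().cpu().clone()
        model.get_parameter(param_key).add_(0.1)  # stand-in for adding a scaled LoRA delta
finally:
    with torch.no_grad():
        for key in modified_cached_weights:
            model.get_parameter(key).copy_(cached_weights[key])
        for key, weight in modified_weights.items():
            model.get_parameter(key).copy_(weight)

assert torch.equal(model.weight, cached_weights["weight"])  # original weights are back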

invokeai/backend/stable_diffusion/extensions/lora.py

Lines changed: 145 additions & 0 deletions
@@ -0,0 +1,145 @@
+from __future__ import annotations
+
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+
+import torch
+from diffusers import UNet2DConditionModel
+
+from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
+from invokeai.backend.util.devices import TorchDevice
+
+if TYPE_CHECKING:
+    from invokeai.app.invocations.model import ModelIdentifierField
+    from invokeai.app.services.shared.invocation_context import InvocationContext
+    from invokeai.backend.lora import LoRAModelRaw
+
+
+class LoRAExt(ExtensionBase):
+    def __init__(
+        self,
+        node_context: InvocationContext,
+        model_id: ModelIdentifierField,
+        weight: float,
+    ):
+        super().__init__()
+        self._node_context = node_context
+        self._model_id = model_id
+        self._weight = weight
+
+    @contextmanager
+    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+        lora_model = self._node_context.models.load(self._model_id).model
+        modified_cached_weights, modified_weights = self.patch_model(
+            model=unet,
+            prefix="lora_unet_",
+            lora=lora_model,
+            lora_weight=self._weight,
+            cached_weights=cached_weights,
+        )
+        del lora_model
+
+        yield modified_cached_weights, modified_weights
+
+    @classmethod
+    def patch_model(
+        cls,
+        model: torch.nn.Module,
+        prefix: str,
+        lora: LoRAModelRaw,
+        lora_weight: float,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
+    ):
+        """
+        Apply a single LoRA to a model.
+        :param model: The model to patch.
+        :param lora: LoRA model to patch in.
+        :param lora_weight: LoRA patch weight.
+        :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
+        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        """
+        if cached_weights is None:
+            cached_weights = {}
+
+        modified_weights: Dict[str, torch.Tensor] = {}
+        modified_cached_weights: Set[str] = set()
+        with torch.no_grad():
+            # assert lora.device.type == "cpu"
+            for layer_key, layer in lora.layers.items():
+                if not layer_key.startswith(prefix):
+                    continue
+
+                # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+                # should be improved in the following ways:
+                # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+                #    LoRA model is applied.
+                # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+                #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+                #    weights to have valid keys.
+                assert isinstance(model, torch.nn.Module)
+                module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+                # All of the LoRA weight calculations will be done on the same device as the module weight.
+                # (Performance will be best if this is a CUDA device.)
+                device = module.weight.device
+                dtype = module.weight.dtype
+
+                layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+                # We intentionally move to the target device first, then cast. Experimentally, this was found to
+                # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+                # same thing in a single call to '.to(...)'.
+                layer.to(device=device)
+                layer.to(dtype=torch.float32)
+
+                # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                for param_name, lora_param_weight in layer.get_parameters(module).items():
+                    param_key = module_key + "." + param_name
+                    module_param = module.get_parameter(param_name)
+
+                    # save original weight
+                    if param_key not in modified_cached_weights and param_key not in modified_weights:
+                        if param_key in cached_weights:
+                            modified_cached_weights.add(param_key)
+                        else:
+                            modified_weights[param_key] = module_param.detach().to(
+                                device=TorchDevice.CPU_DEVICE, copy=True
+                            )
+
+                    if module_param.shape != lora_param_weight.shape:
+                        # TODO: debug on lycoris
+                        lora_param_weight = lora_param_weight.reshape(module_param.shape)
+
+                    lora_param_weight *= lora_weight * layer_scale
+                    module_param += lora_param_weight.to(dtype=dtype)
+
+                layer.to(device=TorchDevice.CPU_DEVICE)
+
+        return modified_cached_weights, modified_weights
+
+    @staticmethod
+    def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
+        assert "." not in lora_key
+
+        if not lora_key.startswith(prefix):
+            raise Exception(f"lora_key with invalid prefix: {lora_key}, {prefix}")
+
+        module = model
+        module_key = ""
+        key_parts = lora_key[len(prefix) :].split("_")
+
+        submodule_name = key_parts.pop(0)
+
+        while len(key_parts) > 0:
+            try:
+                module = module.get_submodule(submodule_name)
+                module_key += "." + submodule_name
+                submodule_name = key_parts.pop(0)
+            except Exception:
+                submodule_name += "_" + key_parts.pop(0)
+
+        module = module.get_submodule(submodule_name)
+        module_key = (module_key + "." + submodule_name).lstrip(".")
+
+        return (module_key, module)
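The trickiest part of the new file is `_resolve_lora_key()`: LoRA checkpoints flatten module paths with underscores, so the resolver greedily grows the candidate submodule name chunk by chunk until `get_submodule()` accepts it, then descends one level and repeats. The standalone toy below (the `Toy`/`Inner` modules are invented for illustration) walks one key through the same algorithm.

import torch
from typing import Tuple


class Inner(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.proj_in = torch.nn.Linear(4, 4)  # submodule whose own name contains "_"


class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.down_blocks = torch.nn.ModuleList([Inner()])


def resolve(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
    # Same walk as LoRAExt._resolve_lora_key: grow the candidate name until
    # get_submodule() succeeds, then descend and continue with the next chunk.
    module = model
    module_key = ""
    key_parts = lora_key[len(prefix):].split("_")
    submodule_name = key_parts.pop(0)
    while len(key_parts) > 0:
        try:
            module = module.get_submodule(submodule_name)
            module_key += "." + submodule_name
            submodule_name = key_parts.pop(0)
        except AttributeError:
            submodule_name += "_" + key_parts.pop(0)
    module = module.get_submodule(submodule_name)
    return (module_key + "." + submodule_name).lstrip("."), module


key, module = resolve(Toy(), "lora_unet_down_blocks_0_proj_in", prefix="lora_unet_")
print(key)                    # down_blocks.0.proj_in
print(type(module).__name__)  # Linear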
