Commit 2227a23
Suggested changes + simplify weights logic in patching
Co-Authored-By: Ryan Dick <[email protected]>
1 parent 8500bac commit 2227a23

6 files changed: +86 -118 lines changed

invokeai/backend/lora.py

Lines changed: 3 additions & 0 deletions
@@ -490,6 +490,9 @@ def from_checkpoint(
         state_dict = cls._convert_sdxl_keys_to_diffusers_format(state_dict)
 
         for layer_key, values in state_dict.items():
+            # Detect layers according to LyCORIS detection logic(`weight_list_det`)
+            # https://github.com/KohakuBlueleaf/LyCORIS/tree/8ad8000efb79e2b879054da8c9356e6143591bad/lycoris/modules
+
             # lora and locon
             if "lora_up.weight" in values:
                 layer: AnyLoRALayer = LoRALayer(layer_key, values)
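
The detection logic referenced in the new comments keys off which parameter names are present in each layer's sub-dictionary. A rough sketch of that idea (simplified; the branch list and return values are illustrative, not the full LyCORIS or InvokeAI logic):

```python
from typing import Dict

import torch


def detect_layer_type(values: Dict[str, torch.Tensor]) -> str:
    """Toy key-based layer detection: the parameter names present in a layer's
    state dict decide how the layer should be interpreted."""
    if "lora_up.weight" in values:
        return "lora/locon"  # classic LoRA or LoCon layer
    if "hada_w1_a" in values:
        return "loha"        # LoHA (Hadamard product) layer
    if "lokr_w1" in values or "lokr_w1_a" in values:
        return "lokr"        # LoKr (Kronecker product) layer
    return "unknown"
```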

invokeai/backend/model_patcher.py

Lines changed: 7 additions & 16 deletions
@@ -5,7 +5,7 @@
 
 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
 
 import numpy as np
 import torch
@@ -123,34 +123,25 @@ def apply_lora(
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
         :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
-        modified_cached_weights: Set[str] = set()
-        modified_weights: Dict[str, torch.Tensor] = {}
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)
         try:
             for lora_model, lora_weight in loras:
-                lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
+                LoRAExt.patch_model(
                     model=model,
                     prefix=prefix,
                     lora=lora_model,
                     lora_weight=lora_weight,
-                    cached_weights=cached_weights,
+                    original_weights=original_weights,
                 )
                 del lora_model
 
-                modified_cached_weights.update(lora_modified_cached_weights)
-                # Store only first returned weight for each key, because
-                # next extension which changes it, will work with already modified weight
-                for param_key, weight in lora_modified_weights.items():
-                    if param_key in modified_weights:
-                        continue
-                    modified_weights[param_key] = weight
-
             yield
 
         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    model.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     model.get_parameter(param_key).copy_(weight)
 
     @classmethod
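
The net effect of this change is a single "save once, restore all" dictionary: the first patch to touch a parameter records its original value, later patches leave that record alone, and the `finally` block copies everything back. A minimal standalone sketch of that pattern (toy module; `patch_param` is an illustrative helper, not InvokeAI API):

```python
from typing import Dict

import torch


def patch_param(model: torch.nn.Module, param_key: str, delta: torch.Tensor,
                original_weights: Dict[str, torch.Tensor]) -> None:
    """Add `delta` to one parameter, saving the original value only once, so
    later patches to the same key never overwrite the true original."""
    param = model.get_parameter(param_key)
    if param_key not in original_weights:
        original_weights[param_key] = param.detach().cpu().clone()
    with torch.no_grad():
        param += delta


model = torch.nn.Linear(4, 4)
original_weights: Dict[str, torch.Tensor] = {}

patch_param(model, "weight", torch.ones(4, 4), original_weights)  # original saved here
patch_param(model, "weight", torch.ones(4, 4), original_weights)  # record left untouched

# Unpatch: copy every saved original back, as the `finally` block above does.
with torch.no_grad():
    for key, weight in original_weights.items():
        model.get_parameter(key).copy_(weight)
```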

invokeai/backend/stable_diffusion/extensions/base.py

Lines changed: 11 additions & 11 deletions
@@ -2,7 +2,7 @@
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Callable, Dict, List
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -56,17 +56,17 @@ def patch_extension(self, ctx: DenoiseContext):
         yield None
 
     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
-        """Apply patches to UNet model. This function responsible for restoring all changes except weights,
-        changed weights should only be reported in return.
-        Return contains 2 values:
-        - Set of cached weights, just keys from cached_weights dictionary
-        - Dict of not cached weights that should be copies on the cpu device
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
+        """A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
+        diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
+        `original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
+        All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
+        context manager.
 
         Args:
             unet (UNet2DConditionModel): The UNet model on execution device to patch.
-            cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict in CPU, for caches purposes.
+            cached_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
+                unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
+                access original weight value.
         """
-        yield set(), {}
+        yield
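
To make the revised contract concrete, a hypothetical extension following it might look like the sketch below (the class name and parameter key are invented for illustration; only the save-before-modify and bare `yield` behavior mirror the docstring above):

```python
from contextlib import contextmanager
from typing import Dict

import torch
from diffusers import UNet2DConditionModel


class ScaleWeightExt:
    """Hypothetical extension: scales one UNet weight for the duration of denoising."""

    def __init__(self, param_key: str, scale: float):
        self._param_key = param_key
        self._scale = scale

    @contextmanager
    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
        param = unet.get_parameter(self._param_key)
        # Save the original weight only if no earlier extension saved it already.
        if self._param_key not in original_weights:
            original_weights[self._param_key] = param.detach().cpu().clone()
        with torch.no_grad():
            param *= self._scale
        # No weight restoration here: the caller copies `original_weights` back in its
        # `finally` block. Only non-weight changes would need cleanup after the yield.
        yield
```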

invokeai/backend/stable_diffusion/extensions/freeu.py

Lines changed: 3 additions & 5 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -21,9 +21,7 @@ def __init__(
         self._freeu_config = freeu_config
 
     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
@@ -32,6 +30,6 @@ def patch_unet(
         )
 
         try:
-            yield set(), {}
+            yield
         finally:
             unet.disable_freeu()

invokeai/backend/stable_diffusion/extensions/lora.py

Lines changed: 56 additions & 69 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
+from typing import TYPE_CHECKING, Dict, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -28,97 +28,84 @@ def __init__(
         self._weight = weight
 
     @contextmanager
-    def patch_unet(
-        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+    def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
         lora_model = self._node_context.models.load(self._model_id).model
-        modified_cached_weights, modified_weights = self.patch_model(
+        self.patch_model(
             model=unet,
             prefix="lora_unet_",
             lora=lora_model,
             lora_weight=self._weight,
-            cached_weights=cached_weights,
+            original_weights=original_weights,
         )
         del lora_model
 
-        yield modified_cached_weights, modified_weights
+        yield
 
     @classmethod
+    @torch.no_grad()
     def patch_model(
        cls,
        model: torch.nn.Module,
        prefix: str,
        lora: LoRAModelRaw,
        lora_weight: float,
-        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+        original_weights: Dict[str, torch.Tensor],
+    ):
         """
         Apply one or more LoRAs to a model.
         :param model: The model to patch.
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :param original_weights: TODO:
         """
-        if cached_weights is None:
-            cached_weights = {}
-
-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
-        with torch.no_grad():
-            # assert lora.device.type == "cpu"
-            for layer_key, layer in lora.layers.items():
-                if not layer_key.startswith(prefix):
-                    continue
-
-                # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                # should be improved in the following ways:
-                # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                #    LoRA model is applied.
-                # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                #    weights to have valid keys.
-                assert isinstance(model, torch.nn.Module)
-                module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                # All of the LoRA weight calculations will be done on the same device as the module weight.
-                # (Performance will be best if this is a CUDA device.)
-                device = module.weight.device
-                dtype = module.weight.dtype
-
-                layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                # same thing in a single call to '.to(...)'.
-                layer.to(device=device)
-                layer.to(dtype=torch.float32)
-
-                # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                for param_name, lora_param_weight in layer.get_parameters(module).items():
-                    param_key = module_key + "." + param_name
-                    module_param = module.get_parameter(param_name)
-
-                    # save original weight
-                    if param_key not in modified_cached_weights and param_key not in modified_weights:
-                        if param_key in cached_weights:
-                            modified_cached_weights.add(param_key)
-                        else:
-                            modified_weights[param_key] = module_param.detach().to(
-                                device=TorchDevice.CPU_DEVICE, copy=True
-                            )
-
-                    if module_param.shape != lora_param_weight.shape:
-                        # TODO: debug on lycoris
-                        lora_param_weight = lora_param_weight.reshape(module_param.shape)
-
-                    lora_param_weight *= lora_weight * layer_scale
-                    module_param += lora_param_weight.to(dtype=dtype)
-
-                layer.to(device=TorchDevice.CPU_DEVICE)
-
-        return modified_cached_weights, modified_weights
+
+        # assert lora.device.type == "cpu"
+        for layer_key, layer in lora.layers.items():
+            if not layer_key.startswith(prefix):
+                continue
+
+            # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
+            # should be improved in the following ways:
+            # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
+            #    LoRA model is applied.
+            # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
+            #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
+            #    weights to have valid keys.
+            assert isinstance(model, torch.nn.Module)
+            module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
+
+            # All of the LoRA weight calculations will be done on the same device as the module weight.
+            # (Performance will be best if this is a CUDA device.)
+            device = module.weight.device
+            dtype = module.weight.dtype
+
+            layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+
+            # We intentionally move to the target device first, then cast. Experimentally, this was found to
+            # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
+            # same thing in a single call to '.to(...)'.
+            layer.to(device=device)
+            layer.to(dtype=torch.float32)
+
+            # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+            # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+            for param_name, lora_param_weight in layer.get_parameters(module).items():
+                param_key = module_key + "." + param_name
+                module_param = module.get_parameter(param_name)
+
+                # save original weight
+                if param_key not in original_weights:
+                    original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
+
+                if module_param.shape != lora_param_weight.shape:
+                    # TODO: debug on lycoris
+                    lora_param_weight = lora_param_weight.reshape(module_param.shape)
+
+                lora_param_weight *= lora_weight * layer_scale
+                module_param += lora_param_weight.to(dtype=dtype)
+
+            layer.to(device=TorchDevice.CPU_DEVICE)
 
     @staticmethod
     def _resolve_lora_key(model: torch.nn.Module, lora_key: str, prefix: str) -> Tuple[str, torch.nn.Module]:
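
The per-parameter arithmetic is unchanged by this refactor: the LoRA delta is scaled by `lora_weight * (alpha / rank)` and added in place. A standalone sketch of that update for a plain up/down LoRA layer (illustrative names, not the `LoRALayer` API):

```python
import torch


def apply_lora_delta(module: torch.nn.Linear, up: torch.Tensor, down: torch.Tensor,
                     alpha: float, rank: int, lora_weight: float) -> None:
    """In-place W += lora_weight * (alpha / rank) * (up @ down), matching the
    scaling used in patch_model above for a basic LoRA/LoCon layer."""
    layer_scale = alpha / rank if (alpha and rank) else 1.0
    delta = (up @ down) * (lora_weight * layer_scale)
    with torch.no_grad():
        module.weight += delta.to(dtype=module.weight.dtype, device=module.weight.device)


linear = torch.nn.Linear(16, 16, bias=False)
rank = 4
up = torch.randn(16, rank)    # [out_features, rank]
down = torch.randn(rank, 16)  # [rank, in_features]
apply_lora_delta(linear, up, down, alpha=4.0, rank=rank, lora_weight=0.75)
```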

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 6 additions & 17 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import ExitStack, contextmanager
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -67,29 +67,18 @@ def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[s
         if self._is_canceled and self._is_canceled():
             raise CanceledException
 
-        modified_weights: Dict[str, torch.Tensor] = {}
-        modified_cached_weights: Set[str] = set()
+        original_weights: Dict[str, torch.Tensor] = {}
+        if cached_weights:
+            original_weights.update(cached_weights)
 
         try:
             with ExitStack() as exit_stack:
                 for ext in self._extensions:
-                    ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
-                        ext.patch_unet(unet, cached_weights)
-                    )
-
-                    modified_cached_weights.update(ext_modified_cached_weights)
-                    # store only first returned weight for each key, because
-                    # next extension which changes it, will work with already modified weight
-                    for param_key, weight in ext_modified_weights.items():
-                        if param_key in modified_weights:
-                            continue
-                        modified_weights[param_key] = weight
+                    exit_stack.enter_context(ext.patch_unet(unet, original_weights))
 
             yield None
 
         finally:
             with torch.no_grad():
-                for param_key in modified_cached_weights:
-                    unet.get_parameter(param_key).copy_(cached_weights[param_key])
-                for param_key, weight in modified_weights.items():
+                for param_key, weight in original_weights.items():
                     unet.get_parameter(param_key).copy_(weight)
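
With `ExitStack`, every extension's `patch_unet` context stays open for the whole denoise and is closed in reverse order, while the weight restore runs exactly once in `finally`. A reduced sketch of that control flow (a generic stand-in context manager, not the actual ExtensionsManager code):

```python
from contextlib import ExitStack, contextmanager
from typing import Dict

import torch


@contextmanager
def fake_extension(name: str, model: torch.nn.Module, original_weights: Dict[str, torch.Tensor]):
    """Stand-in for an extension: save the weight once, modify it, clean up only non-weight state."""
    if "weight" not in original_weights:
        original_weights["weight"] = model.get_parameter("weight").detach().cpu().clone()
    with torch.no_grad():
        model.get_parameter("weight").mul_(2.0)
    try:
        yield
    finally:
        # Non-weight cleanup would go here; weights are restored by the caller below.
        print(f"{name}: exiting")


model = torch.nn.Linear(2, 2)
before = model.weight.detach().clone()
original_weights: Dict[str, torch.Tensor] = {}

try:
    with ExitStack() as exit_stack:
        for name in ["lora", "freeu"]:
            exit_stack.enter_context(fake_extension(name, model, original_weights))
        # The denoising loop would run here, with all patches active.
finally:
    with torch.no_grad():
        for param_key, weight in original_weights.items():
            model.get_parameter(param_key).copy_(weight)

assert torch.equal(model.weight.detach(), before)  # weights fully restored
```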
