
Commit 9e58256

Suggested changes
Co-Authored-By: Ryan Dick <[email protected]>
1 parent faa88f7 commit 9e58256

File tree: 5 files changed, +43 −28 lines


invokeai/backend/lora.py

Lines changed: 3 additions & 2 deletions
@@ -71,6 +71,9 @@ def to(self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype]
         self.bias = self.bias.to(device=device, dtype=dtype)
 
     def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]):
+        """Log a warning if values contains unhandled keys."""
+        # {"alpha", "bias_indices", "bias_values", "bias_size"} are hard-coded, because they are handled by
+        # `LoRALayerBase`. Sub-classes should provide the known_keys that they handled.
         all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
         unknown_keys = set(values.keys()) - all_known_keys
         if unknown_keys:
@@ -232,7 +235,6 @@ def __init__(
         else:
             self.rank = None  # unscaled
 
-        # Although lokr_t1 not used in algo, it still defined in LoKR weights
         self.check_keys(
             values,
             {
@@ -242,7 +244,6 @@ def __init__(
                 "lokr_w2",
                 "lokr_w2_a",
                 "lokr_w2_b",
-                "lokr_t1",
                 "lokr_t2",
             },
         )
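
The check_keys helper logs a warning when a layer's state dict contains keys that no class in the hierarchy consumes; with "lokr_t1" removed from the known set, a LoKR state dict that still carries it is now reported as unhandled. Below is a standalone sketch of the pattern, using simplified stand-in classes rather than the real InvokeAI types:

# Standalone sketch of the check_keys pattern (simplified stand-ins, not the real InvokeAI classes).
import logging
from typing import Dict, Set

import torch

logger = logging.getLogger(__name__)


class LayerBase:
    def check_keys(self, values: Dict[str, torch.Tensor], known_keys: Set[str]) -> None:
        """Log a warning if values contains keys that no class in the hierarchy handles."""
        # Keys handled by the base class are hard-coded; sub-classes pass the keys they consume.
        all_known_keys = known_keys | {"alpha", "bias_indices", "bias_values", "bias_size"}
        unknown_keys = set(values.keys()) - all_known_keys
        if unknown_keys:
            logger.warning(f"Unexpected keys found in layer state dict: {unknown_keys}")


class LoKRLikeLayer(LayerBase):
    def __init__(self, values: Dict[str, torch.Tensor]):
        # Only the keys this layer actually consumes are declared; with "lokr_t1" no longer
        # listed, a state dict that still carries it is flagged as unhandled.
        self.check_keys(
            values,
            {"lokr_w1", "lokr_w1_a", "lokr_w1_b", "lokr_w2", "lokr_w2_a", "lokr_w2_b", "lokr_t2"},
        )


# A stray "lokr_t1" key now triggers the warning:
LoKRLikeLayer({"lokr_w1": torch.zeros(4, 4), "lokr_t1": torch.zeros(2, 2, 2, 2)})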

invokeai/backend/stable_diffusion/extensions/base.py

Lines changed: 15 additions & 3 deletions
@@ -2,7 +2,7 @@
 
 from contextlib import contextmanager
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -56,5 +56,17 @@ def patch_extension(self, ctx: DenoiseContext):
         yield None
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
-        yield None
+    def patch_unet(
+        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
+    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
+        """Apply patches to the UNet model. This function is responsible for restoring all changes except weights;
+        changed weights should only be reported via the return value.
+        The return value contains 2 items:
+        - Set of cached weights: the keys from the cached_weights dictionary that were modified
+        - Dict of non-cached weights: CPU copies of the original values of modified weights not present in the cache
+
+        Args:
+            unet (UNet2DConditionModel): The UNet model, on the execution device, to patch.
+            cached_weights (Optional[Dict[str, torch.Tensor]]): Read-only copy of the model's state dict on CPU, for caching purposes.
+        """
+        yield set(), {}
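
The new contract is that patch_unet restores any non-weight changes itself, while weight changes are only reported: keys that already have a CPU copy in cached_weights go into the returned set, and weights that were not cached are returned as CPU copies of their original values. A minimal sketch of an extension following this contract; ScaleFirstConvExt and the choice of conv_in.weight are illustrative assumptions, not part of this commit:

# Hypothetical extension showing the new patch_unet contract (illustrative only).
from contextlib import contextmanager
from typing import Dict, Optional, Set, Tuple

import torch
from diffusers import UNet2DConditionModel


class ScaleFirstConvExt:
    @contextmanager
    def patch_unet(
        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
        cached_weights = cached_weights or {}
        param_key = "conv_in.weight"
        param = unet.get_parameter(param_key)

        modified_cached_weights: Set[str] = set()
        modified_weights: Dict[str, torch.Tensor] = {}
        if param_key in cached_weights:
            # A CPU copy already exists in the cache: report just the key.
            modified_cached_weights.add(param_key)
        else:
            # No cached copy: hand back our own CPU copy of the original value.
            modified_weights[param_key] = param.detach().to("cpu", copy=True)

        with torch.no_grad():
            param.mul_(1.1)  # the in-place weight change this extension is responsible for

        # Weight restoration is left to the caller; only non-weight changes would be undone here.
        yield modified_cached_weights, modified_weights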

invokeai/backend/stable_diffusion/extensions/freeu.py

Lines changed: 5 additions & 3 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Dict, Optional
+from typing import TYPE_CHECKING, Dict, Optional, Set, Tuple
 
 import torch
 from diffusers import UNet2DConditionModel
@@ -21,7 +21,9 @@ def __init__(
         self._freeu_config = freeu_config
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+    def patch_unet(
+        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
+    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
         unet.enable_freeu(
             b1=self._freeu_config.b1,
             b2=self._freeu_config.b2,
@@ -30,6 +32,6 @@ def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[s
         )
 
         try:
-            yield
+            yield set(), {}
         finally:
             unet.disable_freeu()
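
FreeU only toggles module behaviour via enable_freeu/disable_freeu and never touches weight tensors, so it yields an empty key set and an empty weight dict. A short usage sketch of the caller side; the FreeUExt class name is assumed from this file, and freeu_config and unet are presumed to come from the surrounding pipeline:

# Usage sketch (hypothetical caller; FreeUExt is the assumed class name defined in this file).
ext = FreeUExt(freeu_config)  # freeu_config: the FreeU settings object passed to __init__
with ext.patch_unet(unet) as (modified_cached_weights, modified_weights):
    # FreeU reports no weight changes at all.
    assert modified_cached_weights == set()
    assert modified_weights == {}
    # ... run denoising with FreeU enabled ...
# On exit, the extension itself calls unet.disable_freeu() in its finally block.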

invokeai/backend/stable_diffusion/extensions/lora.py

Lines changed: 5 additions & 3 deletions
@@ -28,7 +28,9 @@ def __init__(
         self._weight = weight
 
     @contextmanager
-    def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
+    def patch_unet(
+        self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[str, torch.Tensor]] = None
+    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
         lora_model = self._node_context.models.load(self._model_id).model
         modified_cached_weights, modified_weights = self.patch_model(
             model=unet,
@@ -49,14 +51,14 @@ def patch_model(
         lora: LoRAModelRaw,
         lora_weight: float,
         cached_weights: Optional[Dict[str, torch.Tensor]] = None,
-    ):
+    ) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
         """
         Apply one or more LoRAs to a model.
         :param model: The model to patch.
         :param lora: LoRA model to patch in.
         :param lora_weight: LoRA patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :param cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
         """
         if cached_weights is None:
             cached_weights = {}
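
patch_model now carries the same return type as patch_unet: the set of modified keys that already have cached CPU copies, plus CPU copies of the original values for keys that do not. A minimal sketch of that bookkeeping, using a hypothetical apply_delta helper in place of the actual LoRA layer math:

# Minimal sketch (hypothetical apply_delta helper, not the InvokeAI implementation) of the
# bookkeeping behind the new return type: record the key when a CPU copy is already cached,
# otherwise stash our own CPU copy of the original weight before modifying it in-place.
from typing import Dict, Optional, Set, Tuple

import torch


def apply_delta(
    model: torch.nn.Module,
    deltas: Dict[str, torch.Tensor],
    cached_weights: Optional[Dict[str, torch.Tensor]] = None,
) -> Tuple[Set[str], Dict[str, torch.Tensor]]:
    cached_weights = cached_weights or {}
    modified_cached_weights: Set[str] = set()
    modified_weights: Dict[str, torch.Tensor] = {}

    with torch.no_grad():
        for param_key, delta in deltas.items():
            param = model.get_parameter(param_key)
            if param_key in cached_weights:
                modified_cached_weights.add(param_key)
            elif param_key not in modified_weights:
                modified_weights[param_key] = param.detach().to("cpu", copy=True)
            param.add_(delta.to(device=param.device, dtype=param.dtype))

    return modified_cached_weights, modified_weights

An unpatcher can then copy the cached or returned originals back over the patched parameters.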

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 15 additions & 17 deletions
@@ -70,26 +70,24 @@ def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[s
         modified_weights: Dict[str, torch.Tensor] = {}
         modified_cached_weights: Set[str] = set()
 
-        exit_stack = ExitStack()
         try:
-            for ext in self._extensions:
-                res = exit_stack.enter_context(ext.patch_unet(unet, cached_weights))
-                if res is None:
-                    continue
-                ext_modified_cached_weights, ext_modified_weights = res
-
-                modified_cached_weights.update(ext_modified_cached_weights)
-                # store only first returned weight for each key, because
-                # next extension which changes it, will work with already modified weight
-                for param_key, weight in ext_modified_weights.items():
-                    if param_key in modified_weights:
-                        continue
-                    modified_weights[param_key] = weight
-
-            yield None
+            with ExitStack() as exit_stack:
+                for ext in self._extensions:
+                    ext_modified_cached_weights, ext_modified_weights = exit_stack.enter_context(
+                        ext.patch_unet(unet, cached_weights)
+                    )
+
+                    modified_cached_weights.update(ext_modified_cached_weights)
+                    # store only first returned weight for each key, because
+                    # next extension which changes it, will work with already modified weight
+                    for param_key, weight in ext_modified_weights.items():
+                        if param_key in modified_weights:
+                            continue
+                        modified_weights[param_key] = weight
+
+                yield None
 
         finally:
-            exit_stack.close()
             with torch.no_grad():
                 for param_key in modified_cached_weights:
                     unet.get_parameter(param_key).copy_(cached_weights[param_key])
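
Replacing the manually closed ExitStack with a with-block keeps the same semantics: each extension's patch_unet context is entered in registration order and unwound in reverse order when the block exits, even if an extension raises. A standalone sketch of that behaviour, using a fake patch context manager in place of real extensions:

# Standalone sketch of the ExitStack pattern used above: contexts are entered in order and
# unwound in reverse order when the with-block exits, even on error.
from contextlib import ExitStack, contextmanager
from typing import Set


@contextmanager
def fake_patch(name: str):
    print(f"patch {name}")
    try:
        # Yield the same shape of result as patch_unet: (modified cached keys, uncached CPU copies).
        yield {f"{name}.weight"}, {}
    finally:
        print(f"unpatch {name}")


modified_cached_weights: Set[str] = set()
with ExitStack() as exit_stack:
    for name in ("lora_a", "freeu"):
        ext_keys, _ext_weights = exit_stack.enter_context(fake_patch(name))
        modified_cached_weights.update(ext_keys)
    print(sorted(modified_cached_weights))  # ['freeu.weight', 'lora_a.weight']
# Exiting the with-block runs the finally blocks in reverse order: "unpatch freeu", then "unpatch lora_a".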
