Skip to content

Commit 86f705b

Browse files
committed
Optimize weights handling
1 parent 1fd9631 commit 86f705b

File tree

6 files changed

+62
-27
lines changed

6 files changed

+62
-27
lines changed

invokeai/backend/model_patcher.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
2020
from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
2121
from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
22+
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
2223

2324
"""
2425
loras = [
@@ -123,9 +124,7 @@ def apply_lora(
123124
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
124125
:cached_weights: Read-only copy of the model's state dict in CPU, for unpatching purposes.
125126
"""
126-
original_weights: Dict[str, torch.Tensor] = {}
127-
if cached_weights:
128-
original_weights.update(cached_weights)
127+
original_weights = OriginalWeightsStorage(cached_weights)
129128
try:
130129
for lora_model, lora_weight in loras:
131130
LoRAExt.patch_model(
@@ -141,7 +140,7 @@ def apply_lora(
141140

142141
finally:
143142
with torch.no_grad():
144-
for param_key, weight in original_weights.items():
143+
for param_key, weight in original_weights.get_changed_weights():
145144
model.get_parameter(param_key).copy_(weight)
146145

147146
@classmethod

invokeai/backend/stable_diffusion/extensions/base.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@
44
from dataclasses import dataclass
55
from typing import TYPE_CHECKING, Callable, Dict, List
66

7-
import torch
87
from diffusers import UNet2DConditionModel
98

109
if TYPE_CHECKING:
1110
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
1211
from invokeai.backend.stable_diffusion.extension_callback_type import ExtensionCallbackType
12+
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
1313

1414

1515
@dataclass
@@ -56,17 +56,17 @@ def patch_extension(self, ctx: DenoiseContext):
5656
yield None
5757

5858
@contextmanager
59-
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
59+
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
6060
"""A context manager for applying patches to the UNet model. The context manager's lifetime spans the entire
61-
diffusion process. Weight unpatching is handled upstream, and is achieved by adding unsaved weights in
62-
`original_weights` dict. Note that this enables some performance optimization by avoiding redundant operations.
63-
All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched by this
64-
context manager.
61+
diffusion process. Weight unpatching is handled upstream, and is achieved by saving unchanged weights by
62+
`original_weights.save` function. Note that this enables some performance optimization by avoiding redundant
63+
operations. All other patches (e.g. changes to tensor shapes, function monkey-patches, etc.) should be unpatched
64+
by this context manager.
6565
6666
Args:
6767
unet (UNet2DConditionModel): The UNet model on execution device to patch.
68-
original_weights (Dict[str, torch.Tensor]]): A read-only copy of the model's original weights in CPU, for
69-
unpatching purposes. Extension can save tensor which being modified, if it is not saved yet, or can
70-
access original weight value.
68+
original_weights (OriginalWeightsStorage): A storage with copy of the model's original weights in CPU, for
69+
unpatching purposes. Extension should save tensor which being modified in this storage, also extensions
70+
can access original weights values.
7171
"""
7272
yield

invokeai/backend/stable_diffusion/extensions/freeu.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,15 @@
11
from __future__ import annotations
22

33
from contextlib import contextmanager
4-
from typing import TYPE_CHECKING, Dict
4+
from typing import TYPE_CHECKING
55

6-
import torch
76
from diffusers import UNet2DConditionModel
87

98
from invokeai.backend.stable_diffusion.extensions.base import ExtensionBase
109

1110
if TYPE_CHECKING:
1211
from invokeai.app.shared.models import FreeUConfig
12+
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
1313

1414

1515
class FreeUExt(ExtensionBase):
@@ -21,7 +21,7 @@ def __init__(
2121
self._freeu_config = freeu_config
2222

2323
@contextmanager
24-
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
24+
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
2525
unet.enable_freeu(
2626
b1=self._freeu_config.b1,
2727
b2=self._freeu_config.b2,

invokeai/backend/stable_diffusion/extensions/lora.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from __future__ import annotations
22

33
from contextlib import contextmanager
4-
from typing import TYPE_CHECKING, Dict, Tuple
4+
from typing import TYPE_CHECKING, Tuple
55

66
import torch
77
from diffusers import UNet2DConditionModel
@@ -13,6 +13,7 @@
1313
from invokeai.app.invocations.model import ModelIdentifierField
1414
from invokeai.app.services.shared.invocation_context import InvocationContext
1515
from invokeai.backend.lora import LoRAModelRaw
16+
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
1617

1718

1819
class LoRAExt(ExtensionBase):
@@ -28,7 +29,7 @@ def __init__(
2829
self._weight = weight
2930

3031
@contextmanager
31-
def patch_unet(self, unet: UNet2DConditionModel, original_weights: Dict[str, torch.Tensor]):
32+
def patch_unet(self, unet: UNet2DConditionModel, original_weights: OriginalWeightsStorage):
3233
lora_model = self._node_context.models.load(self._model_id).model
3334
self.patch_model(
3435
model=unet,
@@ -49,17 +50,20 @@ def patch_model(
4950
prefix: str,
5051
lora: LoRAModelRaw,
5152
lora_weight: float,
52-
original_weights: Dict[str, torch.Tensor],
53+
original_weights: OriginalWeightsStorage,
5354
):
5455
"""
5556
Apply one or more LoRAs to a model.
5657
:param model: The model to patch.
5758
:param lora: LoRA model to patch in.
5859
:param lora_weight: LoRA patch weight.
5960
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
60-
:param original_weights: Dict of original weights, filled by weights which lora patches, used for unpatching.
61+
:param original_weights: Storage with original weights, filled by weights which lora patches, used for unpatching.
6162
"""
6263

64+
if lora_weight == 0:
65+
return
66+
6367
# assert lora.device.type == "cpu"
6468
for layer_key, layer in lora.layers.items():
6569
if not layer_key.startswith(prefix):
@@ -95,8 +99,7 @@ def patch_model(
9599
module_param = module.get_parameter(param_name)
96100

97101
# save original weight
98-
if param_key not in original_weights:
99-
original_weights[param_key] = module_param.detach().to(device=TorchDevice.CPU_DEVICE, copy=True)
102+
original_weights.save(param_key, module_param)
100103

101104
if module_param.shape != lora_param_weight.shape:
102105
# TODO: debug on lycoris

invokeai/backend/stable_diffusion/extensions_manager.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from diffusers import UNet2DConditionModel
88

99
from invokeai.app.services.session_processor.session_processor_common import CanceledException
10+
from invokeai.backend.util.original_weights_storage import OriginalWeightsStorage
1011

1112
if TYPE_CHECKING:
1213
from invokeai.backend.stable_diffusion.denoise_context import DenoiseContext
@@ -67,10 +68,7 @@ def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[s
6768
if self._is_canceled and self._is_canceled():
6869
raise CanceledException
6970

70-
original_weights: Dict[str, torch.Tensor] = {}
71-
if cached_weights:
72-
original_weights.update(cached_weights)
73-
71+
original_weights = OriginalWeightsStorage(cached_weights)
7472
try:
7573
with ExitStack() as exit_stack:
7674
for ext in self._extensions:
@@ -80,5 +78,5 @@ def patch_unet(self, unet: UNet2DConditionModel, cached_weights: Optional[Dict[s
8078

8179
finally:
8280
with torch.no_grad():
83-
for param_key, weight in original_weights.items():
81+
for param_key, weight in original_weights.get_changed_weights():
8482
unet.get_parameter(param_key).copy_(weight)
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
from __future__ import annotations
2+
3+
from typing import Dict, Iterator, Optional, Tuple
4+
5+
import torch
6+
7+
from invokeai.backend.util.devices import TorchDevice
8+
9+
10+
class OriginalWeightsStorage:
11+
def __init__(self, cached_weights: Optional[Dict[str, torch.Tensor]] = None):
12+
self._weights = {}
13+
self._changed_weights = set()
14+
if cached_weights:
15+
self._weights.update(cached_weights)
16+
17+
def save(self, key: str, weight: torch.Tensor, copy: bool = True):
18+
self._changed_weights.add(key)
19+
if key in self._weights:
20+
return
21+
22+
self._weights[key] = weight.detach().to(device=TorchDevice.CPU_DEVICE, copy=copy)
23+
24+
def get(self, key: str, copy: bool = False) -> Optional[torch.Tensor]:
25+
weight = self._weights.get(key, None)
26+
if weight is not None and copy:
27+
weight = weight.clone()
28+
return weight
29+
30+
def contains(self, key: str) -> bool:
31+
return key in self._weights
32+
33+
def get_changed_weights(self) -> Iterator[Tuple[str, torch.Tensor]]:
34+
for key in self._changed_weights:
35+
yield key, self._weights[key]

0 commit comments

Comments
 (0)