 import pickle
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union

 import numpy as np
 import torch

 from invokeai.backend.model_manager import AnyModel
 from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
 from invokeai.backend.onnx.onnx_runtime import IAIOnnxRuntimeModel
+from invokeai.backend.stable_diffusion.extensions.lora import LoRAExt
 from invokeai.backend.textual_inversion import TextualInversionManager, TextualInversionModelRaw
-from invokeai.backend.util.devices import TorchDevice

 """
 loras = [
@@ -85,13 +85,13 @@ def apply_lora_unet(
         cls,
         unet: UNet2DConditionModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         with cls.apply_lora(
             unet,
             loras=loras,
             prefix="lora_unet_",
-            model_state_dict=model_state_dict,
+            cached_weights=cached_weights,
         ):
             yield

@@ -101,9 +101,9 @@ def apply_lora_text_encoder(
         cls,
         text_encoder: CLIPTextModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
-        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
+        with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", cached_weights=cached_weights):
             yield

     @classmethod
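Both wrappers only rename the pass-through argument, so a caller that previously supplied `model_state_dict=` now supplies `cached_weights=`. A rough, hypothetical usage sketch (the `unet`, `lora_model`, and `run_denoising` names are illustrative stand-ins, not part of this diff; `ModelPatcher` is the class being edited here):

    # Keep a CPU copy of the original weights so the patcher can restore them on exit.
    cached = {k: v.detach().to("cpu", copy=True) for k, v in unet.state_dict().items()}

    with ModelPatcher.apply_lora_unet(
        unet,
        loras=[(lora_model, 0.75)],
        cached_weights=cached,  # was `model_state_dict=` before this change
    ):
        run_denoising(unet)  # inside the context, the UNet carries the LoRA deltas
    # On exit, the original weights are copied back.
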
@@ -113,74 +113,45 @@ def apply_lora(
         model: AnyModel,
         loras: Iterator[Tuple[LoRAModelRaw, float]],
         prefix: str,
-        model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
+        cached_weights: Optional[Dict[str, torch.Tensor]] = None,
     ) -> Generator[None, None, None]:
         """
         Apply one or more LoRAs to a model.

         :param model: The model to patch.
         :param loras: An iterator that returns the LoRA to patch in and its patch weight.
         :param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
-        :model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        :param cached_weights: Read-only copy of the model's state dict on the CPU, for unpatching purposes.
         """
-        original_weights = {}
+        modified_cached_weights: Set[str] = set()
+        modified_weights: Dict[str, torch.Tensor] = {}
         try:
-            with torch.no_grad():
-                for lora, lora_weight in loras:
-                    # assert lora.device.type == "cpu"
-                    for layer_key, layer in lora.layers.items():
-                        if not layer_key.startswith(prefix):
-                            continue
-
-                        # TODO(ryand): A non-negligible amount of time is currently spent resolving LoRA keys. This
-                        # should be improved in the following ways:
-                        # 1. The key mapping could be more-efficiently pre-computed. This would save time every time a
-                        #    LoRA model is applied.
-                        # 2. From an API perspective, there's no reason that the `ModelPatcher` should be aware of the
-                        #    intricacies of Stable Diffusion key resolution. It should just expect the input LoRA
-                        #    weights to have valid keys.
-                        assert isinstance(model, torch.nn.Module)
-                        module_key, module = cls._resolve_lora_key(model, layer_key, prefix)
-
-                        # All of the LoRA weight calculations will be done on the same device as the module weight.
-                        # (Performance will be best if this is a CUDA device.)
-                        device = module.weight.device
-                        dtype = module.weight.dtype
-
-                        if module_key not in original_weights:
-                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
-                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
-                            else:
-                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
-
-                        layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
-                        # We intentionally move to the target device first, then cast. Experimentally, this was found to
-                        # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                        # same thing in a single call to '.to(...)'.
-                        layer.to(device=device)
-                        layer.to(dtype=torch.float32)
-                        # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
-                        # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
-                        layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
-                        layer.to(device=TorchDevice.CPU_DEVICE)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        if module.weight.shape != layer_weight.shape:
-                            # TODO: debug on lycoris
-                            assert hasattr(layer_weight, "reshape")
-                            layer_weight = layer_weight.reshape(module.weight.shape)
-
-                        assert isinstance(layer_weight, torch.Tensor)  # mypy thinks layer_weight is a float|Any ??!
-                        module.weight += layer_weight.to(dtype=dtype)
-
-            yield  # wait for context manager exit
+            for lora_model, lora_weight in loras:
+                lora_modified_cached_weights, lora_modified_weights = LoRAExt.patch_model(
+                    model=model,
+                    prefix=prefix,
+                    lora=lora_model,
+                    lora_weight=lora_weight,
+                    cached_weights=cached_weights,
+                )
+                del lora_model
+
+                modified_cached_weights.update(lora_modified_cached_weights)
+                # Store only the first weight returned for each key: it holds the original value.
+                # Any later LoRA that touches the same key sees (and returns) the already-patched weight.
+                for param_key, weight in lora_modified_weights.items():
+                    if param_key in modified_weights:
+                        continue
+                    modified_weights[param_key] = weight
+
+            yield

         finally:
-            assert hasattr(model, "get_submodule")  # mypy not picking up fact that torch.nn.Module has get_submodule()
             with torch.no_grad():
-                for module_key, weight in original_weights.items():
-                    model.get_submodule(module_key).weight.copy_(weight)
+                for param_key in modified_cached_weights:
+                    model.get_parameter(param_key).copy_(cached_weights[param_key])
+                for param_key, weight in modified_weights.items():
+                    model.get_parameter(param_key).copy_(weight)

     @classmethod
     @contextmanager
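To make the new unpatching path concrete, here is a minimal, self-contained sketch of what the `finally:` block above does. `TinyModel`, the hard-coded parameter keys, and the fake `+= 1.0` patch are assumptions for illustration only; `LoRAExt.patch_model` is not reimplemented, and only the `get_parameter(...).copy_(...)` restore pattern mirrors the diff.

    import torch

    class TinyModel(torch.nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.proj = torch.nn.Linear(4, 4)

    model = TinyModel()
    # CPU copy of the original weights, playing the role of `cached_weights`.
    cached_weights = {k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()}

    # Pretend a LoRA patch touched `proj.weight` (its original lives in `cached_weights`)
    # and `proj.bias` (its original was captured explicitly, as in `modified_weights`).
    modified_cached_weights = {"proj.weight"}
    modified_weights = {"proj.bias": model.proj.bias.detach().to("cpu", copy=True)}

    with torch.no_grad():
        model.proj.weight += 1.0  # stand-in for adding a LoRA delta
        model.proj.bias += 1.0

    # Restore exactly as the new `finally:` block does.
    with torch.no_grad():
        for param_key in modified_cached_weights:
            model.get_parameter(param_key).copy_(cached_weights[param_key])
        for param_key, weight in modified_weights.items():
            model.get_parameter(param_key).copy_(weight)

    assert torch.equal(model.proj.weight, cached_weights["proj.weight"])
    assert torch.equal(model.proj.bias, modified_weights["proj.bias"])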