
Commit c2e6d98

Merge branch 'v2.3' into dev/codeowner-fix-2.3

2 parents: 1a704ef + 40d9b5d

3 files changed: +106 additions, -43 deletions


ldm/invoke/config/model_install.py

Lines changed: 0 additions & 24 deletions
@@ -196,16 +196,6 @@ def create(self):
             scroll_exit=True,
         )
         self.nextrely += 1
-        self.convert_models = self.add_widget_intelligent(
-            npyscreen.TitleSelectOne,
-            name="== CONVERT IMPORTED MODELS INTO DIFFUSERS==",
-            values=["Keep original format", "Convert to diffusers"],
-            value=0,
-            begin_entry_at=4,
-            max_height=4,
-            hidden=True,  # will appear when imported models box is edited
-            scroll_exit=True,
-        )
         self.cancel = self.add_widget_intelligent(
             npyscreen.ButtonPress,
             name="CANCEL",
@@ -240,8 +230,6 @@ def create(self):
             self.show_directory_fields.addVisibleWhenSelected(i)

         self.show_directory_fields.when_value_edited = self._clear_scan_directory
-        self.import_model_paths.when_value_edited = self._show_hide_convert
-        self.autoload_directory.when_value_edited = self._show_hide_convert

     def resize(self):
         super().resize()
@@ -252,13 +240,6 @@ def _clear_scan_directory(self):
         if not self.show_directory_fields.value:
             self.autoload_directory.value = ""

-    def _show_hide_convert(self):
-        model_paths = self.import_model_paths.value or ""
-        autoload_directory = self.autoload_directory.value or ""
-        self.convert_models.hidden = (
-            len(model_paths) == 0 and len(autoload_directory) == 0
-        )
-
     def _get_starter_model_labels(self) -> List[str]:
         window_width, window_height = get_terminal_size()
         label_width = 25
@@ -318,7 +299,6 @@ def marshall_arguments(self):
         .scan_directory: Path to a directory of models to scan and import
         .autoscan_on_startup: True if invokeai should scan and import at startup time
         .import_model_paths: list of URLs, repo_ids and file paths to import
-        .convert_to_diffusers: if True, convert legacy checkpoints into diffusers
         """
         # we're using a global here rather than storing the result in the parentapp
         # due to some bug in npyscreen that is causing attributes to be lost
@@ -354,7 +334,6 @@ def marshall_arguments(self):

         # URLs and the like
         selections.import_model_paths = self.import_model_paths.value.split()
-        selections.convert_to_diffusers = self.convert_models.value[0] == 1


 class AddModelApplication(npyscreen.NPSAppManaged):
@@ -367,7 +346,6 @@ def __init__(self):
             scan_directory=None,
             autoscan_on_startup=None,
             import_model_paths=None,
-            convert_to_diffusers=None,
         )

     def onStart(self):
@@ -387,15 +365,13 @@ def process_and_execute(opt: Namespace, selections: Namespace):
     directory_to_scan = selections.scan_directory
     scan_at_startup = selections.autoscan_on_startup
     potential_models_to_install = selections.import_model_paths
-    convert_to_diffusers = selections.convert_to_diffusers

     install_requested_models(
         install_initial_models=models_to_install,
         remove_models=models_to_remove,
         scan_directory=Path(directory_to_scan) if directory_to_scan else None,
         external_models=potential_models_to_install,
         scan_at_startup=scan_at_startup,
-        convert_to_diffusers=convert_to_diffusers,
         precision="float32"
         if opt.full_precision
         else choose_precision(torch.device(choose_torch_device())),
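
Taken together with the backend change in the next file, this removes the convert-to-diffusers option end to end. As a minimal sketch of what process_and_execute() now passes through, assuming the import path matches the backend module shown below and using purely illustrative values:

from pathlib import Path
from ldm.invoke.config.model_install_backend import install_requested_models

# Hypothetical selections; note there is no convert_to_diffusers kwarg any more.
install_requested_models(
    install_initial_models=["some-starter-model"],               # illustrative name
    remove_models=[],
    scan_directory=Path("/some/models/dir"),                     # illustrative path
    external_models=["https://example.com/model.safetensors"],   # illustrative URL
    scan_at_startup=False,
    precision="float32",
)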

ldm/invoke/config/model_install_backend.py

Lines changed: 4 additions & 5 deletions
@@ -68,7 +68,6 @@ def install_requested_models(
     scan_directory: Path = None,
     external_models: List[str] = None,
     scan_at_startup: bool = False,
-    convert_to_diffusers: bool = False,
     precision: str = "float16",
     purge_deleted: bool = False,
     config_file_path: Path = None,
@@ -111,20 +110,20 @@ def install_requested_models(
     if len(external_models)>0:
         print("== INSTALLING EXTERNAL MODELS ==")
         for path_url_or_repo in external_models:
+            print(f'DEBUG: path_url_or_repo = {path_url_or_repo}')
             try:
                 model_manager.heuristic_import(
                     path_url_or_repo,
-                    convert=convert_to_diffusers,
                     config_file_callback=_pick_configuration_file,
                     commit_to_conf=config_file_path
                 )
             except KeyboardInterrupt:
                 sys.exit(-1)
-            except Exception:
-                pass
+            except Exception as e:
+                print(f'An exception has occurred: {str(e)}')

     if scan_at_startup and scan_directory.is_dir():
-        argument = '--autoconvert' if convert_to_diffusers else '--autoimport'
+        argument = '--autoconvert'
         initfile = Path(Globals.root, Globals.initfile)
         replacement = Path(Globals.root, f'{Globals.initfile}.new')
         directory = str(scan_directory).replace('\\','/')
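
Besides dropping the convert flag (the scan-at-startup path now always writes '--autoconvert' instead of choosing between '--autoconvert' and '--autoimport'), this hunk stops swallowing import errors. A standalone sketch of the before/after handling, with a hypothetical stand-in for model_manager.heuristic_import():

def flaky_import(path_url_or_repo: str):
    # Stand-in for model_manager.heuristic_import(); fails on purpose for the demo.
    raise ValueError(f"could not resolve {path_url_or_repo}")

for path_url_or_repo in ["https://example.com/broken-model.safetensors"]:
    try:
        flaky_import(path_url_or_repo)
    except KeyboardInterrupt:
        raise SystemExit(-1)
    except Exception as e:
        # Previously: `except Exception: pass`, so failures were silent.
        print(f'An exception has occurred: {str(e)}')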

ldm/modules/kohya_lora_manager.py

Lines changed: 102 additions & 14 deletions
@@ -31,18 +31,13 @@ def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0

-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):
         if self.mid is None:
-            output = (
-                output
-                + self.up(self.down(*input_h)) * lora.multiplier * self.scale
-            )
+            weight = self.up(self.down(*input_h))
         else:
-            output = (
-                output
-                + self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
-            )
-        return output
+            weight = self.up(self.mid(self.down(*input_h)))
+
+        return weight * lora.multiplier * self.scale

 class LoHALayer:
     lora_name: str
@@ -64,7 +59,7 @@ def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0

-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):

         if type(self.org_module) == torch.nn.Conv2d:
             op = torch.nn.functional.conv2d
@@ -86,16 +81,79 @@ def forward(self, lora, input_h, output):
             rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
             rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
             weight = rebuild1 * rebuild2
-
+
         bias = self.bias if self.bias is not None else 0
-        return output + op(
+        return op(
             *input_h,
             (weight + bias).view(self.org_module.weight.shape),
             None,
             **extra_args,
         ) * lora.multiplier * self.scale


+class LoKRLayer:
+    lora_name: str
+    name: str
+    scale: float
+
+    w1: Optional[torch.Tensor] = None
+    w1_a: Optional[torch.Tensor] = None
+    w1_b: Optional[torch.Tensor] = None
+    w2: Optional[torch.Tensor] = None
+    w2_a: Optional[torch.Tensor] = None
+    w2_b: Optional[torch.Tensor] = None
+    t2: Optional[torch.Tensor] = None
+    bias: Optional[torch.Tensor] = None
+
+    org_module: torch.nn.Module
+
+    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
+        self.lora_name = lora_name
+        self.name = name
+        self.scale = alpha / rank if (alpha and rank) else 1.0
+
+    def forward(self, lora, input_h):
+
+        if type(self.org_module) == torch.nn.Conv2d:
+            op = torch.nn.functional.conv2d
+            extra_args = dict(
+                stride=self.org_module.stride,
+                padding=self.org_module.padding,
+                dilation=self.org_module.dilation,
+                groups=self.org_module.groups,
+            )
+
+        else:
+            op = torch.nn.functional.linear
+            extra_args = {}
+
+        w1 = self.w1
+        if w1 is None:
+            w1 = self.w1_a @ self.w1_b
+
+        w2 = self.w2
+        if w2 is None:
+            if self.t2 is None:
+                w2 = self.w2_a @ self.w2_b
+            else:
+                w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)


+        if len(w2.shape) == 4:
+            w1 = w1.unsqueeze(2).unsqueeze(2)
+        w2 = w2.contiguous()
+        weight = torch.kron(w1, w2).reshape(self.org_module.weight.shape)


+        bias = self.bias if self.bias is not None else 0
+        return op(
+            *input_h,
+            (weight + bias).view(self.org_module.weight.shape),
+            None,
+            **extra_args
+        ) * lora.multiplier * self.scale
+
+
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
     text_encoder: CLIPTextModel
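
For context on the new LoKRLayer: the weight delta is rebuilt as a Kronecker product of two factors, each of which may itself be stored as a low-rank pair. A minimal sketch of that reconstruction for a Linear module, using illustrative shapes rather than values from any real checkpoint:

import torch

out_features, in_features = 320, 768                       # hypothetical Linear weight shape
rank = 4

# Each Kronecker factor may be stored low-rank as a pair of matrices.
w1_a, w1_b = torch.randn(4, rank), torch.randn(rank, 8)    # w1: (4, 8)
w2_a, w2_b = torch.randn(80, rank), torch.randn(rank, 96)  # w2: (80, 96)

w1 = w1_a @ w1_b
w2 = w2_a @ w2_b
delta = torch.kron(w1, w2)                                  # (4*80, 8*96) == (320, 768)
assert delta.shape == (out_features, in_features)

# LoKRLayer.forward() applies this delta via F.linear / F.conv2d and scales it
# by multiplier * (alpha / rank), as in the hunk above.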
@@ -159,7 +217,7 @@ def lora_forward(module, input_h, output):
             layer = lora.layers.get(name, None)
             if layer is None:
                 continue
-            output = layer.forward(lora, input_h, output)
+            output += layer.forward(lora, input_h)
         return output

     return lora_forward
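
This hunk, together with the LoRALayer and LoHALayer changes above, switches the contract between the wrapper hook and the layer classes: forward() no longer receives and rebuilds the module output, it returns only its scaled delta, and the hook sums the deltas. A self-contained sketch of the same pattern on a plain nn.Linear (all names here are illustrative; only the hook structure mirrors the wrapper):

import torch

base = torch.nn.Linear(16, 16, bias=False)   # stands in for a wrapped module
down = torch.nn.Linear(16, 4, bias=False)    # rank-4 LoRA pair
up = torch.nn.Linear(4, 16, bias=False)
multiplier, scale = 0.8, 1.0 / 4             # scale = alpha / rank

def lora_delta(input_h):
    # Like the new LoRALayer.forward(): return only the contribution.
    return up(down(*input_h)) * multiplier * scale

def hook(module, input_h, output):
    # Like lora_forward(): accumulate each layer's delta onto the base output.
    output += lora_delta(input_h)
    return output

base.register_forward_hook(hook)
y = base(torch.randn(2, 16))                 # base output plus the LoRA delta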
@@ -307,6 +365,36 @@ def load_from_dict(self, state_dict):
                 else:
                     layer.t2 = None

+            # lokr
+            elif "lokr_w1_b" in values or "lokr_w1" in values:
+
+                if "lokr_w1_b" in values:
+                    rank = values["lokr_w1_b"].shape[0]
+                elif "lokr_w2_b" in values:
+                    rank = values["lokr_w2_b"].shape[0]
+                else:
+                    rank = None  # unscaled
+
+                layer = LoKRLayer(self.name, stem, rank, alpha)
+                layer.org_module = wrapped
+                layer.bias = bias
+
+                if "lokr_w1" in values:
+                    layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_w2" in values:
+                    layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_t2" in values:
+                    layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)
+
+
             else:
                 print(
                     f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
