Commit 40d9b5d

[Feature] Add support for LoKR LyCORIS format (#3216)
LoKR is like LoHA, but it uses a Kronecker product instead of a Hadamard product: https://github.com/KohakuBlueleaf/LyCORIS#lokr

I tested it on these two LoKR models:
https://civitai.com/models/34518/unofficial-vspo-yakumo-beni
https://civitai.com/models/35136/mika-pikazo-lokr

More test models are hard to find, since the format is new. It is best tested together with #3214. This commit also slightly refactors the layer forward functions.

Note: LyCORIS also has an (IA)^3 format, but I could not find any example files in that format, and even the LyCORIS page marks it as experimental. Until test examples exist, I prefer not to add it.
2 parents 298ccda + da96ec9 commit 40d9b5d
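
As background (not part of the diff below), here is a minimal sketch of the difference the commit message describes, using made-up toy shapes: LoHA combines two full-shape low-rank products element-wise (Hadamard product), while LoKR combines two small factors with torch.kron (Kronecker product).

import torch

# Toy illustration only: rebuild a delta for a hypothetical 8x6 linear weight.

# LoHA: two low-rank products, combined element-wise (Hadamard product),
# so both factors must already have the full 8x6 shape.
w1 = torch.randn(8, 2) @ torch.randn(2, 6)    # (8, 6)
w2 = torch.randn(8, 2) @ torch.randn(2, 6)    # (8, 6)
delta_loha = w1 * w2                          # (8, 6)

# LoKR: two small blocks combined with a Kronecker product; the factor
# shapes only need to multiply out to the target shape.
a = torch.randn(2, 3)
b = torch.randn(4, 2)
delta_lokr = torch.kron(a, b)                 # (2*4, 3*2) = (8, 6)

assert delta_loha.shape == delta_lokr.shape == (8, 6)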

File tree

1 file changed (+102, -14 lines)

ldm/modules/kohya_lora_manager.py

Lines changed: 102 additions & 14 deletions
@@ -31,18 +31,13 @@ def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0
 
-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):
         if self.mid is None:
-            output = (
-                output
-                + self.up(self.down(*input_h)) * lora.multiplier * self.scale
-            )
+            weight = self.up(self.down(*input_h))
         else:
-            output = (
-                output
-                + self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
-            )
-        return output
+            weight = self.up(self.mid(self.down(*input_h)))
+
+        return weight * lora.multiplier * self.scale
 
 class LoHALayer:
     lora_name: str
@@ -64,7 +59,7 @@ def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0
 
-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):
 
         if type(self.org_module) == torch.nn.Conv2d:
             op = torch.nn.functional.conv2d
@@ -86,16 +81,79 @@ def forward(self, lora, input_h, output):
             rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
             rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
             weight = rebuild1 * rebuild2
-
+
         bias = self.bias if self.bias is not None else 0
-        return output + op(
+        return op(
             *input_h,
             (weight + bias).view(self.org_module.weight.shape),
             None,
             **extra_args,
         ) * lora.multiplier * self.scale
 
 
+class LoKRLayer:
+    lora_name: str
+    name: str
+    scale: float
+
+    w1: Optional[torch.Tensor] = None
+    w1_a: Optional[torch.Tensor] = None
+    w1_b: Optional[torch.Tensor] = None
+    w2: Optional[torch.Tensor] = None
+    w2_a: Optional[torch.Tensor] = None
+    w2_b: Optional[torch.Tensor] = None
+    t2: Optional[torch.Tensor] = None
+    bias: Optional[torch.Tensor] = None
+
+    org_module: torch.nn.Module
+
+    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
+        self.lora_name = lora_name
+        self.name = name
+        self.scale = alpha / rank if (alpha and rank) else 1.0
+
+    def forward(self, lora, input_h):
+
+        if type(self.org_module) == torch.nn.Conv2d:
+            op = torch.nn.functional.conv2d
+            extra_args = dict(
+                stride=self.org_module.stride,
+                padding=self.org_module.padding,
+                dilation=self.org_module.dilation,
+                groups=self.org_module.groups,
+            )
+
+        else:
+            op = torch.nn.functional.linear
+            extra_args = {}
+
+        w1 = self.w1
+        if w1 is None:
+            w1 = self.w1_a @ self.w1_b
+
+        w2 = self.w2
+        if w2 is None:
+            if self.t2 is None:
+                w2 = self.w2_a @ self.w2_b
+            else:
+                w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
+
+
+        if len(w2.shape) == 4:
+            w1 = w1.unsqueeze(2).unsqueeze(2)
+        w2 = w2.contiguous()
+        weight = torch.kron(w1, w2).reshape(self.org_module.weight.shape)
+
+
+        bias = self.bias if self.bias is not None else 0
+        return op(
+            *input_h,
+            (weight + bias).view(self.org_module.weight.shape),
+            None,
+            **extra_args
+        ) * lora.multiplier * self.scale
+
+
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
     text_encoder: CLIPTextModel
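
A quick shape check of the torch.kron call in LoKRLayer.forward above, under assumed toy dimensions (the 320x4x3x3 Conv2d weight and the 8x2 block are made up for illustration, not taken from a real checkpoint):

import torch

out_ch, in_ch, k = 320, 4, 3                       # hypothetical Conv2d weight (320, 4, 3, 3)

w1 = torch.randn(8, 2)                             # small LoKR block
w2 = torch.randn(out_ch // 8, in_ch // 2, k, k)    # (40, 2, 3, 3)

# Same trick as the forward() above: lift w1 to 4-D so torch.kron
# multiplies the channel dimensions and keeps w2's kernel dimensions.
w1 = w1.unsqueeze(2).unsqueeze(2)                  # (8, 2, 1, 1)
weight = torch.kron(w1, w2.contiguous())           # (320, 4, 3, 3)

assert weight.shape == (out_ch, in_ch, k, k)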
@@ -159,7 +217,7 @@ def lora_forward(module, input_h, output):
             layer = lora.layers.get(name, None)
             if layer is None:
                 continue
-            output = layer.forward(lora, input_h, output)
+            output += layer.forward(lora, input_h)
         return output
 
     return lora_forward
@@ -307,6 +365,36 @@ def load_from_dict(self, state_dict):
                 else:
                     layer.t2 = None
 
+            # lokr
+            elif "lokr_w1_b" in values or "lokr_w1" in values:
+
+                if "lokr_w1_b" in values:
+                    rank = values["lokr_w1_b"].shape[0]
+                elif "lokr_w2_b" in values:
+                    rank = values["lokr_w2_b"].shape[0]
+                else:
+                    rank = None # unscaled
+
+                layer = LoKRLayer(self.name, stem, rank, alpha)
+                layer.org_module = wrapped
+                layer.bias = bias
+
+                if "lokr_w1" in values:
+                    layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_w2" in values:
+                    layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_t2" in values:
+                    layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)
+
+
             else:
                 print(
                     f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
