@@ -31,18 +31,13 @@ def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0

-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):
         if self.mid is None:
-            output = (
-                output
-                + self.up(self.down(*input_h)) * lora.multiplier * self.scale
-            )
+            weight = self.up(self.down(*input_h))
         else:
-            output = (
-                output
-                + self.up(self.mid(self.down(*input_h))) * lora.multiplier * self.scale
-            )
-        return output
+            weight = self.up(self.mid(self.down(*input_h)))
+
+        return weight * lora.multiplier * self.scale

 class LoHALayer:
     lora_name: str
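
With this hunk, LoRALayer.forward returns only the low-rank delta (up(down(x)) * multiplier * alpha/rank) and leaves accumulation to the caller. A minimal sketch of that contract, using hypothetical standalone down/up modules and shapes in place of the layer's real state:

import torch

# Hypothetical rank-4 LoRA delta for a 64->64 linear layer.
down = torch.nn.Linear(64, 4, bias=False)    # stands in for self.down
up = torch.nn.Linear(4, 64, bias=False)      # stands in for self.up
multiplier, alpha, rank = 1.0, 4.0, 4
scale = alpha / rank if (alpha and rank) else 1.0

x = torch.randn(2, 64)
delta = up(down(x)) * multiplier * scale     # what the new forward() returns
output = torch.nn.Linear(64, 64)(x) + delta  # the caller adds the delta onto the base output
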
@@ -64,7 +59,7 @@ def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
         self.name = name
         self.scale = alpha / rank if (alpha and rank) else 1.0

-    def forward(self, lora, input_h, output):
+    def forward(self, lora, input_h):

         if type(self.org_module) == torch.nn.Conv2d:
             op = torch.nn.functional.conv2d
@@ -86,16 +81,79 @@ def forward(self, lora, input_h, output):
             rebuild1 = torch.einsum('i j k l, j r, i p -> p r k l', self.t1, self.w1_b, self.w1_a)
             rebuild2 = torch.einsum('i j k l, j r, i p -> p r k l', self.t2, self.w2_b, self.w2_a)
             weight = rebuild1 * rebuild2
-
+
         bias = self.bias if self.bias is not None else 0
-        return output + op(
+        return op(
             *input_h,
             (weight + bias).view(self.org_module.weight.shape),
             None,
             **extra_args,
         ) * lora.multiplier * self.scale


+class LoKRLayer:
+    lora_name: str
+    name: str
+    scale: float
+
+    w1: Optional[torch.Tensor] = None
+    w1_a: Optional[torch.Tensor] = None
+    w1_b: Optional[torch.Tensor] = None
+    w2: Optional[torch.Tensor] = None
+    w2_a: Optional[torch.Tensor] = None
+    w2_b: Optional[torch.Tensor] = None
+    t2: Optional[torch.Tensor] = None
+    bias: Optional[torch.Tensor] = None
+
+    org_module: torch.nn.Module
+
+    def __init__(self, lora_name: str, name: str, rank=4, alpha=1.0):
+        self.lora_name = lora_name
+        self.name = name
+        self.scale = alpha / rank if (alpha and rank) else 1.0
+
+    def forward(self, lora, input_h):
+
+        if type(self.org_module) == torch.nn.Conv2d:
+            op = torch.nn.functional.conv2d
+            extra_args = dict(
+                stride=self.org_module.stride,
+                padding=self.org_module.padding,
+                dilation=self.org_module.dilation,
+                groups=self.org_module.groups,
+            )
+
+        else:
+            op = torch.nn.functional.linear
+            extra_args = {}
+
+        w1 = self.w1
+        if w1 is None:
+            w1 = self.w1_a @ self.w1_b
+
+        w2 = self.w2
+        if w2 is None:
+            if self.t2 is None:
+                w2 = self.w2_a @ self.w2_b
+            else:
+                w2 = torch.einsum('i j k l, i p, j r -> p r k l', self.t2, self.w2_a, self.w2_b)
+
+
+        if len(w2.shape) == 4:
+            w1 = w1.unsqueeze(2).unsqueeze(2)
+        w2 = w2.contiguous()
+        weight = torch.kron(w1, w2).reshape(self.org_module.weight.shape)
+
+
+        bias = self.bias if self.bias is not None else 0
+        return op(
+            *input_h,
+            (weight + bias).view(self.org_module.weight.shape),
+            None,
+            **extra_args
+        ) * lora.multiplier * self.scale
+
+
 class LoRAModuleWrapper:
     unet: UNet2DConditionModel
     text_encoder: CLIPTextModel
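
LoKRLayer rebuilds the full weight as a Kronecker product of two blocks, each of which may be stored dense, as a low-rank pair, or (for conv kernels) via the optional t2 core: W = kron(w1, w2), scaled by multiplier * alpha/rank. A shape-only sketch of the reconstruction path, with hypothetical factor sizes for a 320x320 linear weight:

import torch

# Hypothetical factorization of a 320x320 weight: 320 = 8 * 40 on both axes.
w1 = torch.randn(8, 8)                            # first Kronecker block (or w1_a @ w1_b)
w2_a, w2_b = torch.randn(40, 4), torch.randn(4, 40)
w2 = w2_a @ w2_b                                  # low-rank second block, as in forward()
weight = torch.kron(w1, w2)                       # Kronecker product -> (8*40, 8*40)
assert weight.shape == (320, 320)
# For a conv weight (4-D w2), forward() additionally unsqueezes w1 so the product
# broadcasts over the kernel dimensions before reshaping to the module's weight shape.
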
@@ -159,7 +217,7 @@ def lora_forward(module, input_h, output):
                 layer = lora.layers.get(name, None)
                 if layer is None:
                     continue
-                output = layer.forward(lora, input_h, output)
+                output += layer.forward(lora, input_h)
             return output

         return lora_forward
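
Because each layer's forward now returns an independent, self-scaled delta, several adapters stack by plain addition on the wrapped module's output. A toy illustration of that accumulation (multipliers and shapes are hypothetical stand-ins, not the wrapper's real plumbing):

import torch

with torch.no_grad():
    x = torch.randn(2, 16)
    base = torch.nn.Linear(16, 16)
    delta_a = torch.randn(2, 16) * 0.8   # stand-in for layer.forward(lora_a, (x,))
    delta_b = torch.randn(2, 16) * 0.3   # stand-in for layer.forward(lora_b, (x,))

    output = base(x)
    for delta in (delta_a, delta_b):
        output += delta                  # mirrors: output += layer.forward(lora, input_h)
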
@@ -307,6 +365,36 @@ def load_from_dict(self, state_dict):
                 else:
                     layer.t2 = None

+            # lokr
+            elif "lokr_w1_b" in values or "lokr_w1" in values:
+
+                if "lokr_w1_b" in values:
+                    rank = values["lokr_w1_b"].shape[0]
+                elif "lokr_w2_b" in values:
+                    rank = values["lokr_w2_b"].shape[0]
+                else:
+                    rank = None  # unscaled
+
+                layer = LoKRLayer(self.name, stem, rank, alpha)
+                layer.org_module = wrapped
+                layer.bias = bias
+
+                if "lokr_w1" in values:
+                    layer.w1 = values["lokr_w1"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w1_a = values["lokr_w1_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w1_b = values["lokr_w1_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_w2" in values:
+                    layer.w2 = values["lokr_w2"].to(device=self.device, dtype=self.dtype)
+                else:
+                    layer.w2_a = values["lokr_w2_a"].to(device=self.device, dtype=self.dtype)
+                    layer.w2_b = values["lokr_w2_b"].to(device=self.device, dtype=self.dtype)
+
+                if "lokr_t2" in values:
+                    layer.t2 = values["lokr_t2"].to(device=self.device, dtype=self.dtype)
+
+
             else:
                 print(
                     f">> Encountered unknown lora layer module in {self.name}: {stem} - {type(wrapped).__name__}"
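
The loader recognizes a LoKR block by its lokr_* keys and, when the first block is factored, reads the rank from lokr_w1_b.shape[0] (falling back to lokr_w2_b, otherwise leaving the layer unscaled). A small sketch of how rank and scale fall out of such a state dict, with hypothetical tensor sizes:

import torch

values = {
    "lokr_w1_a": torch.randn(8, 4),   # hypothetical (out factor, rank)
    "lokr_w1_b": torch.randn(4, 8),   # hypothetical (rank, in factor) -> rank = 4
    "lokr_w2": torch.randn(40, 40),   # dense second block, carries no rank itself
}
alpha = 4.0
rank = values["lokr_w1_b"].shape[0] if "lokr_w1_b" in values else None
scale = alpha / rank if (alpha and rank) else 1.0   # matches LoKRLayer.__init__
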