11import torch
22from torch import Tensor
3-
43from typing_extensions import Self
54
65
76class LoraConversionKeySet :
87 def __init__ (
9- self ,
10- omi_prefix : str ,
11- diffusers_prefix : str ,
12- legacy_diffusers_prefix : str | None = None ,
13- parent : Self | None = None ,
14- swap_chunks : bool = False ,
15- filter_is_last : bool | None = None ,
16- next_omi_prefix : str | None = None ,
17- next_diffusers_prefix : str | None = None ,
8+ self ,
9+ omi_prefix : str ,
10+ diffusers_prefix : str ,
11+ legacy_diffusers_prefix : str | None = None ,
12+ parent : Self | None = None ,
13+ swap_chunks : bool = False ,
14+ filter_is_last : bool | None = None ,
15+ next_omi_prefix : str | None = None ,
16+ next_diffusers_prefix : str | None = None ,
1817 ):
1918 if parent is not None :
2019 self .omi_prefix = combine (parent .omi_prefix , omi_prefix )
@@ -24,9 +23,11 @@ def __init__(
2423 self .diffusers_prefix = diffusers_prefix
2524
2625 if legacy_diffusers_prefix is None :
27- self .legacy_diffusers_prefix = self .diffusers_prefix .replace ('.' , '_' )
26+ self .legacy_diffusers_prefix = self .diffusers_prefix .replace ("." , "_" )
2827 elif parent is not None :
29- self .legacy_diffusers_prefix = combine (parent .legacy_diffusers_prefix , legacy_diffusers_prefix ).replace ('.' , '_' )
28+ self .legacy_diffusers_prefix = combine (parent .legacy_diffusers_prefix , legacy_diffusers_prefix ).replace (
29+ "." , "_"
30+ )
3031 else :
3132 self .legacy_diffusers_prefix = legacy_diffusers_prefix
3233
@@ -42,11 +43,13 @@ def __init__(
4243 elif next_omi_prefix is not None and parent is not None :
4344 self .next_omi_prefix = combine (parent .omi_prefix , next_omi_prefix )
4445 self .next_diffusers_prefix = combine (parent .diffusers_prefix , next_diffusers_prefix )
45- self .next_legacy_diffusers_prefix = combine (parent .legacy_diffusers_prefix , next_diffusers_prefix ).replace ('.' , '_' )
46+ self .next_legacy_diffusers_prefix = combine (parent .legacy_diffusers_prefix , next_diffusers_prefix ).replace (
47+ "." , "_"
48+ )
4649 elif next_omi_prefix is not None and parent is None :
4750 self .next_omi_prefix = next_omi_prefix
4851 self .next_diffusers_prefix = next_diffusers_prefix
49- self .next_legacy_diffusers_prefix = next_diffusers_prefix .replace ('.' , '_' )
52+ self .next_legacy_diffusers_prefix = next_diffusers_prefix .replace ("." , "_" )
5053 else :
5154 self .next_omi_prefix = None
5255 self .next_diffusers_prefix = None
@@ -61,19 +64,19 @@ def __get_diffusers(self, in_prefix: str, key: str) -> str:
6164 def __get_legacy_diffusers (self , in_prefix : str , key : str ) -> str :
6265 key = self .legacy_diffusers_prefix + key .removeprefix (in_prefix )
6366
64- suffix = key [key .rfind ('.' ) :]
65- if suffix not in [' .alpha' , ' .dora_scale' ]: # some keys only have a single . in the suffix
66- suffix = key [key .removesuffix (suffix ).rfind ('.' ) :]
67+ suffix = key [key .rfind ("." ) :]
68+ if suffix not in [" .alpha" , " .dora_scale" ]: # some keys only have a single . in the suffix
69+ suffix = key [key .removesuffix (suffix ).rfind ("." ) :]
6770 key = key .removesuffix (suffix )
6871
69- return key .replace ('.' , '_' ) + suffix
72+ return key .replace ("." , "_" ) + suffix
7073
7174 def get_key (self , in_prefix : str , key : str , target : str ) -> str :
72- if target == ' omi' :
75+ if target == " omi" :
7376 return self .__get_omi (in_prefix , key )
74- elif target == ' diffusers' :
77+ elif target == " diffusers" :
7578 return self .__get_diffusers (in_prefix , key )
76- elif target == ' legacy_diffusers' :
79+ elif target == " legacy_diffusers" :
7780 return self .__get_legacy_diffusers (in_prefix , key )
7881 return key
7982
@@ -82,8 +85,8 @@ def __str__(self) -> str:
8285
8386
8487def combine (left : str , right : str ) -> str :
85- left = left .rstrip ('.' )
86- right = right .lstrip ('.' )
88+ left = left .rstrip ("." )
89+ right = right .lstrip ("." )
8790 if left == "" or left is None :
8891 return right
8992 elif right == "" or right is None :
@@ -93,25 +96,28 @@ def combine(left: str, right: str) -> str:
9396
9497
9598def map_prefix_range (
96- omi_prefix : str ,
97- diffusers_prefix : str ,
98- parent : LoraConversionKeySet ,
99+ omi_prefix : str ,
100+ diffusers_prefix : str ,
101+ parent : LoraConversionKeySet ,
99102) -> list [LoraConversionKeySet ]:
100103 # 100 should be a safe upper bound. increase if it's not enough in the future
101- return [LoraConversionKeySet (
102- omi_prefix = f"{ omi_prefix } .{ i } " ,
103- diffusers_prefix = f"{ diffusers_prefix } .{ i } " ,
104- parent = parent ,
105- next_omi_prefix = f"{ omi_prefix } .{ i + 1 } " ,
106- next_diffusers_prefix = f"{ diffusers_prefix } .{ i + 1 } " ,
107- ) for i in range (100 )]
104+ return [
105+ LoraConversionKeySet (
106+ omi_prefix = f"{ omi_prefix } .{ i } " ,
107+ diffusers_prefix = f"{ diffusers_prefix } .{ i } " ,
108+ parent = parent ,
109+ next_omi_prefix = f"{ omi_prefix } .{ i + 1 } " ,
110+ next_diffusers_prefix = f"{ diffusers_prefix } .{ i + 1 } " ,
111+ )
112+ for i in range (100 )
113+ ]
108114
109115
110116def __convert (
111- state_dict : dict [str , Tensor ],
112- key_sets : list [LoraConversionKeySet ],
113- source : str ,
114- target : str ,
117+ state_dict : dict [str , Tensor ],
118+ key_sets : list [LoraConversionKeySet ],
119+ source : str ,
120+ target : str ,
115121) -> dict [str , Tensor ]:
116122 out_states = {}
117123
@@ -121,25 +127,25 @@ def __convert(
121127 # TODO: maybe replace with a non O(n^2) algorithm
122128 for key , tensor in state_dict .items ():
123129 for key_set in key_sets :
124- in_prefix = ''
130+ in_prefix = ""
125131
126- if source == ' omi' :
132+ if source == " omi" :
127133 in_prefix = key_set .omi_prefix
128- elif source == ' diffusers' :
134+ elif source == " diffusers" :
129135 in_prefix = key_set .diffusers_prefix
130- elif source == ' legacy_diffusers' :
136+ elif source == " legacy_diffusers" :
131137 in_prefix = key_set .legacy_diffusers_prefix
132138
133139 if not key .startswith (in_prefix ):
134140 continue
135141
136142 if key_set .filter_is_last is not None :
137143 next_prefix = None
138- if source == ' omi' :
144+ if source == " omi" :
139145 next_prefix = key_set .next_omi_prefix
140- elif source == ' diffusers' :
146+ elif source == " diffusers" :
141147 next_prefix = key_set .next_diffusers_prefix
142- elif source == ' legacy_diffusers' :
148+ elif source == " legacy_diffusers" :
143149 next_prefix = key_set .next_legacy_diffusers_prefix
144150
145151 is_last = not any (k .startswith (next_prefix ) for k in state_dict )
@@ -148,8 +154,8 @@ def __convert(
148154
149155 name = key_set .get_key (in_prefix , key , target )
150156
151- can_swap_chunks = target == ' omi' or source == ' omi'
152- if key_set .swap_chunks and name .endswith (' .lora_up.weight' ) and can_swap_chunks :
157+ can_swap_chunks = target == " omi" or source == " omi"
158+ if key_set .swap_chunks and name .endswith (" .lora_up.weight" ) and can_swap_chunks :
153159 chunk_0 , chunk_1 = tensor .chunk (2 , dim = 0 )
154160 tensor = torch .cat ([chunk_1 , chunk_0 ], dim = 0 )
155161
@@ -161,8 +167,8 @@ def __convert(
161167
162168
163169def __detect_source (
164- state_dict : dict [str , Tensor ],
165- key_sets : list [LoraConversionKeySet ],
170+ state_dict : dict [str , Tensor ],
171+ key_sets : list [LoraConversionKeySet ],
166172) -> str :
167173 omi_count = 0
168174 diffusers_count = 0
@@ -178,34 +184,34 @@ def __detect_source(
178184 legacy_diffusers_count += 1
179185
180186 if omi_count > diffusers_count and omi_count > legacy_diffusers_count :
181- return ' omi'
187+ return " omi"
182188 if diffusers_count > omi_count and diffusers_count > legacy_diffusers_count :
183- return ' diffusers'
189+ return " diffusers"
184190 if legacy_diffusers_count > omi_count and legacy_diffusers_count > diffusers_count :
185- return ' legacy_diffusers'
191+ return " legacy_diffusers"
186192
187- return ''
193+ return ""
188194
189195
190196def convert_to_omi (
191- state_dict : dict [str , Tensor ],
192- key_sets : list [LoraConversionKeySet ],
197+ state_dict : dict [str , Tensor ],
198+ key_sets : list [LoraConversionKeySet ],
193199) -> dict [str , Tensor ]:
194200 source = __detect_source (state_dict , key_sets )
195- return __convert (state_dict , key_sets , source , ' omi' )
201+ return __convert (state_dict , key_sets , source , " omi" )
196202
197203
198204def convert_to_diffusers (
199- state_dict : dict [str , Tensor ],
200- key_sets : list [LoraConversionKeySet ],
205+ state_dict : dict [str , Tensor ],
206+ key_sets : list [LoraConversionKeySet ],
201207) -> dict [str , Tensor ]:
202208 source = __detect_source (state_dict , key_sets )
203- return __convert (state_dict , key_sets , source , ' diffusers' )
209+ return __convert (state_dict , key_sets , source , " diffusers" )
204210
205211
206212def convert_to_legacy_diffusers (
207- state_dict : dict [str , Tensor ],
208- key_sets : list [LoraConversionKeySet ],
213+ state_dict : dict [str , Tensor ],
214+ key_sets : list [LoraConversionKeySet ],
209215) -> dict [str , Tensor ]:
210216 source = __detect_source (state_dict , key_sets )
211- return __convert (state_dict , key_sets , source , ' legacy_diffusers' )
217+ return __convert (state_dict , key_sets , source , " legacy_diffusers" )
0 commit comments