|
12 | 12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | 13 | # See the License for the specific language governing permissions and |
14 | 14 | # limitations under the License. |
15 | | -"""Conversion script for the Stable Diffusion checkpoints.""" |
| 15 | + |
| 16 | +""" |
| 17 | +Conversion scripts for the various modeling checkpoints. These scripts convert original model implementations to |
| 18 | +Diffusers adapted versions. This usually only involves renaming/remapping the state dict keys and changing some |
| 19 | +modeling components partially (for example, splitting a single QKV linear to individual Q, K, V layers). |
| 20 | +""" |
16 | 21 |
|
17 | 22 | import copy |
18 | 23 | import os |
|
92 | 97 | "double_blocks.0.img_attn.norm.key_norm.scale", |
93 | 98 | "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale", |
94 | 99 | ], |
| 100 | + "autoencoder_dc": "decoder.stages.0.op_list.0.main.conv.conv.weight", |
95 | 101 | } |
96 | 102 |
|
97 | 103 | DIFFUSERS_DEFAULT_PIPELINE_PATHS = { |
@@ -2198,3 +2204,251 @@ def swap_scale_shift(weight): |
2198 | 2204 | ) |
2199 | 2205 |
|
2200 | 2206 | return converted_state_dict |
| 2207 | + |
| 2208 | + |
def create_autoencoder_dc_config_from_original(original_config, checkpoint, **kwargs):
    """Build the Diffusers ``AutoencoderDC`` config for a known DC-AE variant.

    Args:
        original_config (dict): Original model config; ``"model_name"`` selects the
            variant and defaults to ``"dc-ae-f32c32-sana-1.0"`` when absent.
        checkpoint (dict): Original state dict. Unused here; accepted for interface
            parity with the other ``create_*_config_from_original`` helpers.
        **kwargs: Ignored; accepted for interface parity.

    Returns:
        dict: Keyword arguments for ``AutoencoderDC``, plus a ``"model_name"`` entry
        that ``convert_autoencoder_dc_checkpoint_to_diffusers`` pops before loading.

    Raises:
        ValueError: If ``model_name`` is not a supported DC-AE variant.
    """
    model_name = original_config.get("model_name", "dc-ae-f32c32-sana-1.0")

    if model_name in ["dc-ae-f32c32-sana-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "decoder_block_types": (
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ),
            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
            "downsample_block_type": "conv",
            "upsample_block_type": "interpolate",
            "decoder_norm_types": "rms_norm",
            "decoder_act_fns": "silu",
            "scaling_factor": 0.41407,
        }
    elif model_name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
        config = {
            "latent_channels": 32,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
        }
        if model_name == "dc-ae-f32c32-in-1.0":
            config["scaling_factor"] = 0.3189
        elif model_name == "dc-ae-f32c32-mix-1.0":
            config["scaling_factor"] = 0.4552
    elif model_name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
        config = {
            "latent_channels": 128,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
        }
        if model_name == "dc-ae-f64c128-in-1.0":
            config["scaling_factor"] = 0.2889
        elif model_name == "dc-ae-f64c128-mix-1.0":
            config["scaling_factor"] = 0.4538
    elif model_name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
        config = {
            "latent_channels": 512,
            "encoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "decoder_block_types": [
                "ResBlock",
                "ResBlock",
                "ResBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
                "EfficientViTBlock",
            ],
            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
            "decoder_norm_types": [
                "batch_norm",
                "batch_norm",
                "batch_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
                "rms_norm",
            ],
            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
        }
        if model_name == "dc-ae-f128c512-in-1.0":
            config["scaling_factor"] = 0.4883
        elif model_name == "dc-ae-f128c512-mix-1.0":
            config["scaling_factor"] = 0.3620
    else:
        # Previously an unknown name fell through and crashed with an unbound
        # `config`; fail loudly with a clear message instead.
        raise ValueError(f"Unsupported DC-AE model name: {model_name!r}")

    # Carried along so the state-dict converter can pick variant-specific remaps.
    config["model_name"] = model_name

    return config
| 2367 | + |
| 2368 | + |
def convert_autoencoder_dc_checkpoint_to_diffusers(config, checkpoint, **kwargs):
    """Convert an original DC-AE state dict to the Diffusers ``AutoencoderDC`` layout.

    This only renames/remaps keys and splits fused QKV convolutions into separate
    Q, K, V projection weights; tensor values are otherwise left untouched.

    Args:
        config (dict): Config from ``create_autoencoder_dc_config_from_original``.
            Its ``"model_name"`` entry is popped (the dict is mutated).
        checkpoint (dict): Original state dict; drained in place.
        **kwargs: Ignored; accepted for interface parity with other converters.

    Returns:
        dict: The converted state dict.
    """
    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
    model_name = config.pop("model_name")

    def remap_qkv_(key: str, state_dict):
        # Split a fused QKV conv weight (1x1 conv) into separate Q/K/V linear weights.
        qkv = state_dict.pop(key)
        q, k, v = torch.chunk(qkv, 3, dim=0)
        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
        # squeeze() drops the trailing 1x1 conv dims so the weights fit nn.Linear.
        state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
        state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
        state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()

    def remap_proj_conv_(key: str, state_dict):
        # Attention output projection: 1x1 conv weight -> linear weight.
        parent_module, _, _ = key.rpartition(".proj.conv.weight")
        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()

    # Substring replacements applied to every key, in insertion order — the order
    # matters (e.g. "conv.conv." must collapse before the per-variant overrides).
    AE_KEYS_RENAME_DICT = {
        # common
        "main.": "",
        "op_list.": "",
        "context_module": "attn",
        "local_module": "conv_out",
        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
        # If there were more scales, there would be more layers, so a loop would be better to handle this
        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
        "depth_conv.conv": "conv_depth",
        "inverted_conv.conv": "conv_inverted",
        "point_conv.conv": "conv_point",
        "point_conv.norm": "norm",
        "conv.conv.": "conv.",
        "conv1.conv": "conv1",
        "conv2.conv": "conv2",
        "conv2.norm": "norm",
        "proj.norm": "norm_out",
        # encoder
        "encoder.project_in.conv": "encoder.conv_in",
        "encoder.project_out.0.conv": "encoder.conv_out",
        "encoder.stages": "encoder.down_blocks",
        # decoder
        "decoder.project_in.conv": "decoder.conv_in",
        "decoder.project_out.0": "decoder.norm_out",
        "decoder.project_out.2.conv": "decoder.conv_out",
        "decoder.stages": "decoder.up_blocks",
    }

    # The non-SANA variants (f32c32-in/mix, f64c128, f128c512) keep an extra
    # ``.conv`` level on the project-in/out convolutions; all of them share this
    # same override (previously duplicated as three identical dicts).
    AE_PROJECT_CONV_KEYS = {
        "encoder.project_in.conv": "encoder.conv_in.conv",
        "decoder.project_out.2.conv": "decoder.conv_out.conv",
    }

    AE_SPECIAL_KEYS_REMAP = {
        "qkv.conv.weight": remap_qkv_,
        "proj.conv.weight": remap_proj_conv_,
    }

    if (
        ("f32c32" in model_name and "sana" not in model_name)
        or "f64c128" in model_name
        or "f128c512" in model_name
    ):
        AE_KEYS_RENAME_DICT.update(AE_PROJECT_CONV_KEYS)

    # First pass: plain substring renames.
    for key in list(converted_state_dict.keys()):
        new_key = key
        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
            new_key = new_key.replace(replace_key, rename_key)
        converted_state_dict[new_key] = converted_state_dict.pop(key)

    # Second pass: structural remaps (QKV split, proj conv -> linear).
    for key in list(converted_state_dict.keys()):
        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
            if special_key not in key:
                continue
            handler_fn_inplace(key, converted_state_dict)

    return converted_state_dict
0 commit comments