Commit d6c748c: "from original file loader"
Parent: 632ad3b

4 files changed (+264, -7 lines)
scripts/convert_dcae_to_diffusers.py

Lines changed: 0 additions & 4 deletions
```diff
@@ -8,10 +8,6 @@
 from diffusers import AutoencoderDC
 
 
-def remove_keys_(key: str, state_dict: Dict[str, Any]):
-    state_dict.pop(key)
-
-
 def remap_qkv_(key: str, state_dict: Dict[str, Any]):
     qkv = state_dict.pop(key)
     q, k, v = torch.chunk(qkv, 3, dim=0)
```
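
This hunk drops the `remove_keys_` helper; `remap_qkv_` (kept) carries the core trick of the conversion: splitting a fused QKV 1x1-conv weight into separate Q, K, V linear weights. A minimal, self-contained sketch of that pattern, with a made-up layer size and key name for illustration:

```python
from typing import Any, Dict

import torch

# Hypothetical fused checkpoint entry: one (3*dim, dim, 1, 1) conv weight for Q/K/V.
dim = 8
state_dict: Dict[str, Any] = {"block.qkv.conv.weight": torch.randn(3 * dim, dim, 1, 1)}

key = "block.qkv.conv.weight"
qkv = state_dict.pop(key)
q, k, v = torch.chunk(qkv, 3, dim=0)  # split along the output-channel axis
parent_module, _, _ = key.rpartition(".qkv.conv.weight")  # -> "block"
state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()  # drop the 1x1 spatial dims
state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()

print(sorted(state_dict))  # ['block.to_k.weight', 'block.to_q.weight', 'block.to_v.weight']
print(state_dict["block.to_q.weight"].shape)  # torch.Size([8, 8])
```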

src/diffusers/loaders/single_file_model.py

Lines changed: 7 additions & 1 deletion
```diff
@@ -23,12 +23,14 @@
 from .single_file_utils import (
     SingleFileComponentError,
     convert_animatediff_checkpoint_to_diffusers,
+    convert_autoencoder_dc_checkpoint_to_diffusers,
     convert_controlnet_checkpoint,
     convert_flux_transformer_checkpoint_to_diffusers,
     convert_ldm_unet_checkpoint,
     convert_ldm_vae_checkpoint,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
+    create_autoencoder_dc_config_from_original,
     create_controlnet_diffusers_config_from_ldm,
     create_unet_diffusers_config_from_ldm,
     create_vae_diffusers_config_from_ldm,
@@ -82,6 +84,10 @@
         "checkpoint_mapping_fn": convert_flux_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "AutoencoderDC": {
+        "checkpoint_mapping_fn": convert_autoencoder_dc_checkpoint_to_diffusers,
+        "config_mapping_fn": create_autoencoder_dc_config_from_original,
+    },
 }
 
 
@@ -228,7 +234,7 @@ def from_single_file(cls, pretrained_model_link_or_path_or_dict: Optional[str] =
         if config_mapping_fn is None:
             raise ValueError(
                 (
-                    f"`original_config` has been provided for {mapping_class_name} but no mapping function"
+                    f"`original_config` has been provided for {mapping_class_name} but no mapping function "
                    "was found to convert the original config to a Diffusers config in"
                    "`diffusers.loaders.single_file_utils`"
                )
```
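
The one-line change in the last hunk fixes an implicit string-concatenation bug: adjacent Python string literals join with no separator, so the old message rendered as "...no mapping functionwas found...". A quick demonstration (note that the next pair of literals in the same message still joins "in" and the module path without a space):

```python
broken = ("no mapping function" "was found")  # adjacent literals concatenate directly
fixed = ("no mapping function " "was found")
print(broken)  # no mapping functionwas found
print(fixed)   # no mapping function was found
```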

src/diffusers/loaders/single_file_utils.py

Lines changed: 255 additions & 1 deletion
```diff
@@ -12,7 +12,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-"""Conversion script for the Stable Diffusion checkpoints."""
+
+"""
+Conversion scripts for the various modeling checkpoints. These scripts convert original model implementations to
+Diffusers adapted versions. This usually only involves renaming/remapping the state dict keys and changing some
+modeling components partially (for example, splitting a single QKV linear to individual Q, K, V layers).
+"""
 
 import copy
 import os
```
```diff
@@ -92,6 +97,7 @@
         "double_blocks.0.img_attn.norm.key_norm.scale",
         "model.diffusion_model.double_blocks.0.img_attn.norm.key_norm.scale",
     ],
+    "autoencoder_dc": "decoder.stages.0.op_list.0.main.conv.conv.weight",
 }
 
 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
```
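
The new `CHECKPOINT_KEY_NAMES` entry acts as a fingerprint: if a raw state dict contains this key, the checkpoint can be recognized as a DC-AE model without any external config. A sketch of that detection pattern (the helper name is illustrative, not the actual diffusers inference function):

```python
# Illustrative only: recognize a checkpoint family by a key unique to it.
DC_AE_FINGERPRINT = "decoder.stages.0.op_list.0.main.conv.conv.weight"

def looks_like_autoencoder_dc(state_dict: dict) -> bool:
    return DC_AE_FINGERPRINT in state_dict
```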
```diff
@@ -2198,3 +2204,251 @@ def swap_scale_shift(weight):
     )
 
     return converted_state_dict
+
+
+def create_autoencoder_dc_config_from_original(original_config, checkpoint, **kwargs):
+    model_name = original_config.get("model_name", "dc-ae-f32c32-sana-1.0")
+    print("trying:", model_name)
+
+    if model_name in ["dc-ae-f32c32-sana-1.0"]:
+        config = {
+            "latent_channels": 32,
+            "encoder_block_types": (
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ),
+            "decoder_block_types": (
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ),
+            "encoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
+            "decoder_block_out_channels": (128, 256, 512, 512, 1024, 1024),
+            "encoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
+            "decoder_qkv_multiscales": ((), (), (), (5,), (5,), (5,)),
+            "encoder_layers_per_block": (2, 2, 2, 3, 3, 3),
+            "decoder_layers_per_block": [3, 3, 3, 3, 3, 3],
+            "downsample_block_type": "conv",
+            "upsample_block_type": "interpolate",
+            "decoder_norm_types": "rms_norm",
+            "decoder_act_fns": "silu",
+            "scaling_factor": 0.41407,
+        }
+    elif model_name in ["dc-ae-f32c32-in-1.0", "dc-ae-f32c32-mix-1.0"]:
+        config = {
+            "latent_channels": 32,
+            "encoder_block_types": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ],
+            "decoder_block_types": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ],
+            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
+            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024],
+            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2],
+            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2],
+            "encoder_qkv_multiscales": ((), (), (), (), (), ()),
+            "decoder_qkv_multiscales": ((), (), (), (), (), ()),
+            "decoder_norm_types": ["batch_norm", "batch_norm", "batch_norm", "rms_norm", "rms_norm", "rms_norm"],
+            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu"],
+        }
+        if model_name == "dc-ae-f32c32-in-1.0":
+            config["scaling_factor"] = 0.3189
+        elif model_name == "dc-ae-f32c32-mix-1.0":
+            config["scaling_factor"] = 0.4552
+    elif model_name in ["dc-ae-f64c128-in-1.0", "dc-ae-f64c128-mix-1.0"]:
+        config = {
+            "latent_channels": 128,
+            "encoder_block_types": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ],
+            "decoder_block_types": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ],
+            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
+            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048],
+            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2],
+            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2],
+            "encoder_qkv_multiscales": ((), (), (), (), (), (), ()),
+            "decoder_qkv_multiscales": ((), (), (), (), (), (), ()),
+            "decoder_norm_types": [
+                "batch_norm",
+                "batch_norm",
+                "batch_norm",
+                "rms_norm",
+                "rms_norm",
+                "rms_norm",
+                "rms_norm",
+            ],
+            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu"],
+        }
+        if model_name == "dc-ae-f64c128-in-1.0":
+            config["scaling_factor"] = 0.2889
+        elif model_name == "dc-ae-f64c128-mix-1.0":
+            config["scaling_factor"] = 0.4538
+    elif model_name in ["dc-ae-f128c512-in-1.0", "dc-ae-f128c512-mix-1.0"]:
+        config = {
+            "latent_channels": 512,
+            "encoder_block_types": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ],
+            "decoder_block_types": [
+                "ResBlock",
+                "ResBlock",
+                "ResBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+                "EfficientViTBlock",
+            ],
+            "encoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
+            "decoder_block_out_channels": [128, 256, 512, 512, 1024, 1024, 2048, 2048],
+            "encoder_layers_per_block": [0, 4, 8, 2, 2, 2, 2, 2],
+            "decoder_layers_per_block": [0, 5, 10, 2, 2, 2, 2, 2],
+            "encoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
+            "decoder_qkv_multiscales": ((), (), (), (), (), (), (), ()),
+            "decoder_norm_types": [
+                "batch_norm",
+                "batch_norm",
+                "batch_norm",
+                "rms_norm",
+                "rms_norm",
+                "rms_norm",
+                "rms_norm",
+                "rms_norm",
+            ],
+            "decoder_act_fns": ["relu", "relu", "relu", "silu", "silu", "silu", "silu", "silu"],
+        }
+        if model_name == "dc-ae-f128c512-in-1.0":
+            config["scaling_factor"] = 0.4883
+        elif model_name == "dc-ae-f128c512-mix-1.0":
+            config["scaling_factor"] = 0.3620
+
+    config.update({"model_name": model_name})
+
+    return config
+
+
+def convert_autoencoder_dc_checkpoint_to_diffusers(config, checkpoint, **kwargs):
+    converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
+    model_name = config.pop("model_name")
+
+    def remap_qkv_(key: str, state_dict):
+        qkv = state_dict.pop(key)
+        q, k, v = torch.chunk(qkv, 3, dim=0)
+        parent_module, _, _ = key.rpartition(".qkv.conv.weight")
+        state_dict[f"{parent_module}.to_q.weight"] = q.squeeze()
+        state_dict[f"{parent_module}.to_k.weight"] = k.squeeze()
+        state_dict[f"{parent_module}.to_v.weight"] = v.squeeze()
+
+    def remap_proj_conv_(key: str, state_dict):
+        parent_module, _, _ = key.rpartition(".proj.conv.weight")
+        state_dict[f"{parent_module}.to_out.weight"] = state_dict.pop(key).squeeze()
+
+    AE_KEYS_RENAME_DICT = {
+        # common
+        "main.": "",
+        "op_list.": "",
+        "context_module": "attn",
+        "local_module": "conv_out",
+        # NOTE: The below two lines work because scales in the available configs only have a tuple length of 1
+        # If there were more scales, there would be more layers, so a loop would be better to handle this
+        "aggreg.0.0": "to_qkv_multiscale.0.proj_in",
+        "aggreg.0.1": "to_qkv_multiscale.0.proj_out",
+        "depth_conv.conv": "conv_depth",
+        "inverted_conv.conv": "conv_inverted",
+        "point_conv.conv": "conv_point",
+        "point_conv.norm": "norm",
+        "conv.conv.": "conv.",
+        "conv1.conv": "conv1",
+        "conv2.conv": "conv2",
+        "conv2.norm": "norm",
+        "proj.norm": "norm_out",
+        # encoder
+        "encoder.project_in.conv": "encoder.conv_in",
+        "encoder.project_out.0.conv": "encoder.conv_out",
+        "encoder.stages": "encoder.down_blocks",
+        # decoder
+        "decoder.project_in.conv": "decoder.conv_in",
+        "decoder.project_out.0": "decoder.norm_out",
+        "decoder.project_out.2.conv": "decoder.conv_out",
+        "decoder.stages": "decoder.up_blocks",
+    }
+
+    AE_F32C32_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_F64C128_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_F128C512_KEYS = {
+        "encoder.project_in.conv": "encoder.conv_in.conv",
+        "decoder.project_out.2.conv": "decoder.conv_out.conv",
+    }
+
+    AE_SPECIAL_KEYS_REMAP = {
+        "qkv.conv.weight": remap_qkv_,
+        "proj.conv.weight": remap_proj_conv_,
+    }
+
+    if "f32c32" in model_name and "sana" not in model_name:
+        AE_KEYS_RENAME_DICT.update(AE_F32C32_KEYS)
+    elif "f64c128" in model_name:
+        AE_KEYS_RENAME_DICT.update(AE_F64C128_KEYS)
+    elif "f128c512" in model_name:
+        AE_KEYS_RENAME_DICT.update(AE_F128C512_KEYS)
+
+    for key in list(converted_state_dict.keys()):
+        new_key = key[:]
+        for replace_key, rename_key in AE_KEYS_RENAME_DICT.items():
+            new_key = new_key.replace(replace_key, rename_key)
+        converted_state_dict[new_key] = converted_state_dict.pop(key)
+
+    for key in list(converted_state_dict.keys()):
+        for special_key, handler_fn_inplace in AE_SPECIAL_KEYS_REMAP.items():
+            if special_key not in key:
+                continue
+            handler_fn_inplace(key, converted_state_dict)
+
+    return converted_state_dict
```
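
As a usage sketch for the config builder above: it reads only `model_name` from the original config (falling back to the SANA f32c32 variant) and currently ignores the checkpoint argument; the inputs below are hypothetical:

```python
original_config = {"model_name": "dc-ae-f64c128-in-1.0"}  # hypothetical input
config = create_autoencoder_dc_config_from_original(original_config, checkpoint={})
print(config["latent_channels"])  # 128
print(config["scaling_factor"])   # 0.2889
print(config["model_name"])       # kept so the state-dict converter can pick a key map
```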

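And a hand-traced walkthrough of the two-pass state-dict converter, applied to the DC-AE fingerprint key (dummy tensor, assuming both functions above are in scope; the expected output was derived manually from `AE_KEYS_RENAME_DICT`):

```python
import torch

checkpoint = {"decoder.stages.0.op_list.0.main.conv.conv.weight": torch.randn(128, 128, 3, 3)}
config = {"model_name": "dc-ae-f32c32-sana-1.0"}  # as produced by the config builder

converted = convert_autoencoder_dc_checkpoint_to_diffusers(config, checkpoint)
print(list(converted))
# ['decoder.up_blocks.0.0.conv.weight']
# "main." and "op_list." are stripped, "conv.conv." collapses to "conv.",
# and "decoder.stages" is renamed to "decoder.up_blocks".
```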
src/diffusers/models/autoencoders/autoencoder_dc.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -20,6 +20,7 @@
 import torch.nn.functional as F
 
 from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin
 from ..activations import get_activation
 from ..attention_processor import SanaMultiscaleLinearAttention
 from ..modeling_utils import ModelMixin
@@ -394,7 +395,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         return hidden_states
 
 
-class AutoencoderDC(ModelMixin, ConfigMixin):
+class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     An Autoencoder model introduced in [DCAE](https://arxiv.org/abs/2410.10733) and used in
     [SANA](https://arxiv.org/abs/2410.10629).
```
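
With `FromOriginalModelMixin` in the bases and the `SINGLE_FILE_LOADABLE_CLASSES` registration above, `AutoencoderDC` gains `from_single_file`. A hedged usage sketch: the checkpoint path is a placeholder, and passing `original_config` as a dict with a `model_name` entry is an assumption based on how the new config builder reads it:

```python
from diffusers import AutoencoderDC

# Hypothetical original-format DC-AE checkpoint.
ae = AutoencoderDC.from_single_file(
    "path/to/dc-ae-f32c32-sana-1.0.safetensors",
    original_config={"model_name": "dc-ae-f32c32-sana-1.0"},
)
```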
