|
74 | 74 | "stable_cascade_stage_b": "down_blocks.1.0.channelwise.0.weight", |
75 | 75 | "stable_cascade_stage_c": "clip_txt_mapper.weight", |
76 | 76 | "sd3": "model.diffusion_model.joint_blocks.0.context_block.adaLN_modulation.1.bias", |
77 | | - "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.1.pos_encoder.pe", |
| 77 | + "animatediff": "down_blocks.0.motion_modules.0.temporal_transformer.transformer_blocks.0.attention_blocks.0.pos_encoder.pe", |
78 | 78 | "animatediff_v2": "mid_block.motion_modules.0.temporal_transformer.norm.bias", |
79 | 79 | "animatediff_sdxl_beta": "up_blocks.2.motion_modules.0.temporal_transformer.norm.weight", |
| 80 | + "animatediff_scribble": "controlnet_cond_embedding.conv_in.weight", |
| 81 | + "animatediff_rgb": "controlnet_cond_embedding.weight", |
80 | 82 | "flux": "double_blocks.0.img_attn.norm.key_norm.scale", |
81 | 83 | } |
82 | 84 |
|
|
111 | 113 | "animatediff_v2": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-2"}, |
112 | 114 | "animatediff_v3": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-v1-5-3"}, |
113 | 115 | "animatediff_sdxl_beta": {"pretrained_model_name_or_path": "guoyww/animatediff-motion-adapter-sdxl-beta"}, |
| 116 | + "animatediff_scribble": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-scribble"}, |
| 117 | + "animatediff_rgb": {"pretrained_model_name_or_path": "guoyww/animatediff-sparsectrl-rgb"}, |
114 | 118 | "flux-dev": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-dev"}, |
115 | 119 | "flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"}, |
116 | 120 | } |
@@ -494,7 +498,13 @@ def infer_diffusers_model_type(checkpoint): |
494 | 498 | model_type = "sd3" |
495 | 499 |
|
496 | 500 | elif CHECKPOINT_KEY_NAMES["animatediff"] in checkpoint: |
497 | | - if CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: |
| 501 | + if CHECKPOINT_KEY_NAMES["animatediff_scribble"] in checkpoint: |
| 502 | + model_type = "animatediff_scribble" |
| 503 | + |
| 504 | + elif CHECKPOINT_KEY_NAMES["animatediff_rgb"] in checkpoint: |
| 505 | + model_type = "animatediff_rgb" |
| 506 | + |
| 507 | + elif CHECKPOINT_KEY_NAMES["animatediff_v2"] in checkpoint: |
498 | 508 | model_type = "animatediff_v2" |
499 | 509 |
|
500 | 510 | elif checkpoint[CHECKPOINT_KEY_NAMES["animatediff_sdxl_beta"]].shape[-1] == 320: |
|
0 commit comments