|
4 | 4 | from einops import rearrange, repeat |
5 | 5 | import lightning as pl |
6 | 6 | from diffsynth import ModelManager, SVDImageEncoder, SVDUNet, SVDVAEEncoder, ContinuousODEScheduler, load_state_dict |
7 | | -from diffsynth.pipelines.stable_video_diffusion import SVDCLIPImageProcessor |
| 7 | +from diffsynth.pipelines.svd_video import SVDCLIPImageProcessor |
8 | 8 | from diffsynth.models.svd_unet import TemporalAttentionBlock |
9 | 9 |
|
10 | 10 |
|
@@ -131,14 +131,14 @@ def __init__(self, learning_rate=1e-5, svd_ckpt_path=None, add_positional_conv=1 |
131 | 131 | self.image_encoder.requires_grad_(False) |
132 | 132 |
|
133 | 133 | self.unet = SVDUNet(add_positional_conv=add_positional_conv).to(dtype=torch.float16, device=self.device) |
134 | | - self.unet.load_state_dict(SVDUNet.state_dict_converter().from_civitai(state_dict), strict=False) |
| 134 | + self.unet.load_state_dict(SVDUNet.state_dict_converter().from_civitai(state_dict, add_positional_conv=add_positional_conv), strict=False) |
135 | 135 | self.unet.train() |
136 | 136 | self.unet.requires_grad_(False) |
137 | 137 | for block in self.unet.blocks: |
138 | 138 | if isinstance(block, TemporalAttentionBlock): |
139 | 139 | block.requires_grad_(True) |
140 | 140 |
|
141 | | - self.vae_encoder = SVDVAEEncoder.to(dtype=torch.float16, device=self.device) |
| 141 | + self.vae_encoder = SVDVAEEncoder().to(dtype=torch.float16, device=self.device) |
142 | 142 | self.vae_encoder.load_state_dict(SVDVAEEncoder.state_dict_converter().from_civitai(state_dict)) |
143 | 143 | self.vae_encoder.eval() |
144 | 144 | self.vae_encoder.requires_grad_(False) |
|
0 commit comments