@@ -23,7 +23,7 @@
 from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
-from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...image_processor import PipelineImageInput
 from ...loaders import WanLoraLoaderMixin
 from ...models import AutoencoderKLWan, WanAnimateTransformer3DModel
 from ...schedulers import UniPCMultistepScheduler
@@ -978,9 +978,9 @@ def __call__(
         image_height, image_width = self.video_processor.get_default_height_width(image)
         if image_height != height or image_width != width:
             logger.warning(f"Reshaping reference image from ({image_width}, {image_height}) to ({width}, {height})")
-        image_pixels = self.vae_image_processor.preprocess(
-            image, height=height, width=width, resize_mode="fill"
-        ).to(device, dtype=torch.float32)
+        image_pixels = self.vae_image_processor.preprocess(image, height=height, width=width, resize_mode="fill").to(
+            device, dtype=torch.float32
+        )
 
         # Get CLIP features from the reference image
         if image_embeds is None:
@@ -1174,9 +1174,9 @@ def __call__(
                 .view(1, self.vae.config.z_dim, 1, 1, 1)
                 .to(latents.device, latents.dtype)
             )
-            latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(1, self.vae.config.z_dim, 1, 1, 1).to(
-                latents.device, latents.dtype
-            )
+            latents_recip_std = 1.0 / torch.tensor(self.vae.config.latents_std).view(
+                1, self.vae.config.z_dim, 1, 1, 1
+            ).to(latents.device, latents.dtype)
             latents = latents / latents_recip_std + latents_mean
             # Skip the first latent frame (used for conditioning)
             out_frames = self.vae.decode(latents[:, :, 1:], return_dict=False)[0]
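For reference, the block around line 1177 de-normalizes the latents before decoding: the Wan VAE config carries per-channel latents_mean and latents_std statistics, the pipeline keeps the reciprocal of the std, and latents / latents_recip_std + latents_mean therefore maps the normalized latents back to the VAE's native scale ahead of vae.decode. A minimal sketch of that arithmetic, using made-up shapes and statistics rather than the real config values:

import torch

z_dim = 16                                  # latent channel count (illustrative)
latents = torch.randn(1, z_dim, 5, 8, 8)    # (batch, channels, frames, height, width)

# Stand-ins for self.vae.config.latents_mean / latents_std, broadcast over all but the channel dim
latents_mean = torch.randn(z_dim).view(1, z_dim, 1, 1, 1)
latents_std = (torch.rand(z_dim) + 0.5).view(1, z_dim, 1, 1, 1)

# The pipeline stores the reciprocal, so dividing by it multiplies by the std
latents_recip_std = 1.0 / latents_std
denormalized = latents / latents_recip_std + latents_mean

# Equivalent direct form
assert torch.allclose(denormalized, latents * latents_std + latents_mean)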