 
 import regex as re
 import torch
-from transformers import AutoTokenizer, UMT5EncoderModel
+from transformers import AutoTokenizer, CLIPImageProcessor, CLIPVisionModel, UMT5EncoderModel
 
 from ...configuration_utils import FrozenDict
 from ...guiders import ClassifierFreeGuidance
+from ...image_processor import PipelineImageInput
+from ...models import AutoencoderKLWan
 from ...utils import is_ftfy_available, logging
+from ...video_processor import VideoProcessor
 from ..modular_pipeline import PipelineBlock, PipelineState
 from ..modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
 from .modular_pipeline import WanModularPipeline
@@ -51,6 +54,20 @@ def prompt_clean(text):
     return text


+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 class WanTextEncoderStep(PipelineBlock):
     model_name = "wan"
 
@@ -240,3 +257,233 @@ def __call__(self, components: WanModularPipeline, state: PipelineState) -> Pipe
         # Add outputs
         self.set_block_state(state, block_state)
         return components, state
+
+
+class WanImageEncodeStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return "Image Encoder step to compute image embeddings to guide the video generation"
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("image_encoder", CLIPVisionModel),
+            ComponentSpec("image_processor", CLIPImageProcessor),
+        ]
+
+    @property
+    def expected_configs(self) -> List[ConfigSpec]:
+        return []
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam(
+                "image",
+                required=True,
+                description="The input image to condition the generation on for first-frame conditioned video generation.",
+            ),
+            InputParam(
+                "last_image",
+                required=False,
+                description="The last image to condition the generation on for last-frame conditioned video generation.",
+            ),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "encoder_hidden_states_image",
+                type_hint=torch.Tensor,
+                description="Image embeddings used to guide the video generation.",
+            ),
+        ]
+
+    @staticmethod
+    def check_inputs(block_state):
+        if not isinstance(block_state.image, PipelineImageInput):
+            raise ValueError(f"`image` has to be of type `PipelineImageInput` but is {type(block_state.image)}.")
+        if block_state.last_image is not None and not isinstance(block_state.last_image, PipelineImageInput):
+            raise ValueError(
+                f"`last_image` has to be of type `PipelineImageInput` but is {type(block_state.last_image)}."
+            )
+
+    @staticmethod
+    def encode_image(
+        components,
+        image: PipelineImageInput,
+        device: torch.device,
+    ):
+        image = components.image_processor(images=image, return_tensors="pt").to(device)
+        image_embeds = components.image_encoder(**image, output_hidden_states=True)
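+        # The penultimate hidden state of the CLIP vision encoder is used as the image conditioning embedding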
+        return image_embeds.hidden_states[-2]
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        # Get inputs and intermediates
+        block_state = self.get_block_state(state)
+        self.check_inputs(block_state)
+
+        block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+        block_state.device = components._execution_device
+
+        # Encode input images
+        image = block_state.image
+        if block_state.last_image is not None:
+            image = [block_state.image, block_state.last_image]
+
+        block_state.encoder_hidden_states_image = self.encode_image(components, image, block_state.device)
+
+        # Add outputs
+        self.set_block_state(state, block_state)
+        return components, state
+
+
+class WanVaeEncoderStep(PipelineBlock):
+    model_name = "wan"
+
+    @property
+    def description(self) -> str:
+        return (
+            "VAE encode step that encodes the input image/last_image to latents for conditioning the video generation"
+        )
+
+    @property
+    def expected_components(self) -> List[ComponentSpec]:
+        return [
+            ComponentSpec("vae", AutoencoderKLWan),
+            ComponentSpec(
+                "video_processor",
+                VideoProcessor,
+                config=FrozenDict({"vae_scale_factor": 8}),
+                default_creation_method="from_config",
+            ),
+        ]
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        return [
+            InputParam("image", required=True),
+            InputParam("last_image", required=False),
+            InputParam("height", type_hint=int),
+            InputParam("width", type_hint=int),
+            InputParam("num_frames", type_hint=int),
+        ]
+
+    @property
+    def intermediate_inputs(self) -> List[InputParam]:
+        return [
+            InputParam("batch_size", type_hint=int),
+            InputParam("num_channels_latents", type_hint=int),
+            InputParam("generator"),
+            InputParam("dtype", type_hint=torch.dtype, description="Data type of model tensor inputs"),
+        ]
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(
+                "latent_condition",
+                type_hint=torch.Tensor,
+                description="The latents representing the reference first-frame/last-frame for conditioned video generation.",
+            )
+        ]
+
+    def _encode_vae_image(
+        self,
+        components: WanModularPipeline,
+        batch_size: int,
+        height: int,
+        width: int,
+        num_frames: int,
+        image: torch.Tensor,
+        device: torch.device,
+        dtype: torch.dtype,
+        last_image: Optional[torch.Tensor] = None,
+        generator: Optional[torch.Generator] = None,
+    ):
+        latent_height = height // components.vae_scale_factor_spatial
+        latent_width = width // components.vae_scale_factor_spatial
+
+        latents_mean = (
+            torch.tensor(components.vae.config.latents_mean)
+            .view(1, components.vae.config.z_dim, 1, 1, 1)
+            .to(device, dtype)
+        )
+        latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
+            1, components.vae.config.z_dim, 1, 1, 1
+        ).to(device, dtype)
+
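+        # Pad the conditioning frame(s) with zero frames so the VAE encodes a clip spanning the full video length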
+        image = image.unsqueeze(2)
+        if last_image is None:
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
+            )
+        else:
+            last_image = last_image.unsqueeze(2)
+            video_condition = torch.cat(
+                [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 2, height, width), last_image],
+                dim=2,
+            )
+        video_condition = video_condition.to(device=device, dtype=dtype)
+
+        if isinstance(generator, list):
+            latent_condition = [
+                retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax") for _ in generator
+            ]
+            latent_condition = torch.cat(latent_condition)
+        else:
+            latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
+            latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
+
+        latent_condition = latent_condition.to(dtype)
+        latent_condition = (latent_condition - latents_mean) * latents_std
+
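+        # Binary mask over video frames: 1 for the conditioned first (and optional last) frame, 0 elsewhere,
+        # folded into the temporally compressed latent grid and prepended to the latent channels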
+        mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
+        if last_image is None:
+            mask_lat_size[:, :, list(range(1, num_frames))] = 0
+        else:
+            mask_lat_size[:, :, list(range(1, num_frames - 1))] = 0
+        first_frame_mask = mask_lat_size[:, :, 0:1]
+        first_frame_mask = torch.repeat_interleave(
+            first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal
+        )
+        mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
+        mask_lat_size = mask_lat_size.view(
+            batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width
+        )
+        mask_lat_size = mask_lat_size.transpose(1, 2)
+        mask_lat_size = mask_lat_size.to(latent_condition.device)
+        latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
+
+        return latent_condition
+
+    @torch.no_grad()
+    def __call__(self, components: WanModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        block_state.device = components._execution_device
+        block_state.dtype = block_state.dtype if block_state.dtype is not None else components.vae.dtype
+        block_state.num_channels_latents = components.vae.config.z_dim
+        block_state.batch_size = (
+            block_state.batch_size if block_state.batch_size is not None else block_state.image.shape[0]
+        )
+
+        block_state.image = components.video_processor.preprocess(
+            block_state.image, height=block_state.height, width=block_state.width
+        ).to(block_state.device, dtype=torch.float32)
+        if block_state.last_image is not None:
+            block_state.last_image = components.video_processor.preprocess(
+                block_state.last_image, height=block_state.height, width=block_state.width
+            ).to(block_state.device, dtype=torch.float32)
+
+        block_state.latent_condition = self._encode_vae_image(
+            components,
+            batch_size=block_state.batch_size,
+            height=block_state.height,
+            width=block_state.width,
+            num_frames=block_state.num_frames,
+            image=block_state.image,
+            device=block_state.device,
+            dtype=block_state.dtype,
+            last_image=block_state.last_image,
+            generator=block_state.generator,
+        )
+
+        self.set_block_state(state, block_state)
+        return components, state
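
As a quick sanity check on the shape bookkeeping in `WanVaeEncoderStep._encode_vae_image`, the sketch below reproduces only the mask construction with plain `torch`, stubbing out the VAE encode with a random latent. The 8x spatial / 4x temporal compression, the 16-channel latent space, and the 480x832, 81-frame example size are assumptions chosen to match common Wan defaults, not values read from this diff.

```python
import torch

# Assumed values (not taken from this diff): typical Wan VAE compression factors and latent channel count
vae_scale_factor_spatial, vae_scale_factor_temporal, z_dim = 8, 4, 16
batch_size, height, width, num_frames = 1, 480, 832, 81

latent_height = height // vae_scale_factor_spatial                 # 60
latent_width = width // vae_scale_factor_spatial                   # 104
latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1  # 21

# Stand-in for retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
latent_condition = torch.randn(batch_size, z_dim, latent_frames, latent_height, latent_width)

# Mask construction, first-frame-only case (last_image is None)
mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
mask_lat_size[:, :, list(range(1, num_frames))] = 0
first_frame_mask = mask_lat_size[:, :, 0:1]
first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=vae_scale_factor_temporal)
mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
mask_lat_size = mask_lat_size.view(batch_size, -1, vae_scale_factor_temporal, latent_height, latent_width)
mask_lat_size = mask_lat_size.transpose(1, 2)

latent_condition = torch.concat([mask_lat_size, latent_condition], dim=1)
print(latent_condition.shape)  # torch.Size([1, 20, 21, 60, 104])
```

The printed shape shows the four mask channels prepended to the 16 latent channels along dim 1, matching the `latent_condition` layout the block returns.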