|
19 | 19 | import torch.nn as nn |
20 | 20 | import torch.nn.functional as F |
21 | 21 |
|
| 22 | +from ...configuration_utils import ConfigMixin, register_to_config |
22 | 23 | from ...utils import logging |
| 24 | +from ...utils.accelerate_utils import apply_forward_hook |
23 | 25 | from ..activations import get_activation |
| 26 | +from ..modeling_utils import ModelMixin |
| 27 | +from .vae import DecoderOutput |
24 | 28 |
|
25 | 29 |
|
26 | 30 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
@@ -175,15 +179,15 @@ def __init__( |
175 | 179 | self, |
176 | 180 | in_channels: int, |
177 | 181 | out_channels: Optional[int] = None, |
178 | | - non_linearity: str = "swish", |
| 182 | + act_fn: str = "swish", |
179 | 183 | ): |
180 | 184 | super().__init__() |
181 | 185 |
|
182 | 186 | out_channels = out_channels or in_channels |
183 | 187 |
|
184 | 188 | self.in_channels = in_channels |
185 | 189 | self.out_channels = out_channels |
186 | | - self.nonlinearity = get_activation(non_linearity) |
| 190 | + self.nonlinearity = get_activation(act_fn) |
187 | 191 |
|
188 | 192 | self.norm1 = MochiChunkedGroupNorm3D(num_channels=in_channels) |
189 | 193 | self.conv1 = MochiChunkedCausalConv3d( |
@@ -377,11 +381,11 @@ def __init__( |
377 | 381 | layers_per_block: Tuple[int, ...] = (3, 3, 4, 6, 3), |
378 | 382 | temporal_expansions: Tuple[int, ...] = (1, 2, 3), |
379 | 383 | spatial_expansions: Tuple[int, ...] = (2, 2, 2), |
380 | | - non_linearity: str = "swish", |
| 384 | + act_fn: str = "swish", |
381 | 385 | ): |
382 | 386 | super().__init__() |
383 | 387 |
|
384 | | - self.nonlinearity = get_activation(non_linearity) |
| 388 | + self.nonlinearity = get_activation(act_fn) |
385 | 389 |
|
386 | 390 | self.conv_in = nn.Conv3d(in_channels, block_out_channels[-1], kernel_size=(1, 1, 1)) |
387 | 391 | self.block_in = MochiMidBlock3D( |
@@ -436,3 +440,261 @@ def create_forward(*inputs): |
436 | 440 | hidden_states = self.conv_out(hidden_states) |
437 | 441 |
|
438 | 442 | return hidden_states |
| 443 | + |
| 444 | + |
| 445 | +class AutoencoderKLMochi(ModelMixin, ConfigMixin): |
| 446 | + r""" |
| 447 | + A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos. Used in
| 448 | + [Mochi 1 preview](https://github.com/genmoai/models). |
| 449 | +
|
| 450 | + This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
| 451 | + for all models (such as downloading or saving). |
| 452 | +
|
| 453 | + Parameters: |
| 454 | + out_channels (`int`, *optional*, defaults to 3): Number of channels in the decoded output.
| 455 | + latent_channels (`int`, *optional*, defaults to 12): Number of channels in the latent representation.
| 456 | + block_out_channels (`Tuple[int]`, *optional*, defaults to `(128, 256, 256, 512)`):
| 457 | + Tuple of block output channels.
| 458 | + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
| 459 | + scaling_factor (`float`, *optional*, defaults to `1.0`):
| 460 | + The component-wise standard deviation of the trained latent space computed using the first batch of the |
| 461 | + training set. This is used to scale the latent space to have unit variance when training the diffusion |
| 462 | + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the |
| 463 | + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 |
| 464 | + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image |
| 465 | + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. |
| 466 | + """ |
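As a reference for the scaling described in the docstring above, here is a minimal sketch (not part of this commit) of how the config's `latents_mean`, `latents_std`, and `scaling_factor` would typically be combined to un-normalize latents before decoding; the helper name and the pipeline-side conventions are assumptions.

```python
import torch

def denormalize_latents(latents: torch.Tensor, vae) -> torch.Tensor:
    # Hypothetical helper: undo the per-channel standardization and the `scaling_factor`
    # scaling described in the docstring (`z = z * scaling_factor` for the diffusion model,
    # `z = z / scaling_factor` when decoding).
    mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(latents)
    std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(latents)
    return latents * std / vae.config.scaling_factor + mean
```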
| 467 | + |
| 468 | + _supports_gradient_checkpointing = True |
| 469 | + _no_split_modules = ["MochiResnetBlock3D"] |
| 470 | + |
| 471 | + @register_to_config |
| 472 | + def __init__( |
| 473 | + self, |
| 474 | + out_channels: int = 3, |
| 475 | + block_out_channels: Tuple[int] = (128, 256, 256, 512), |
| 476 | + latent_channels: int = 12, |
| 477 | + layers_per_block: int = 3, |
| 478 | + act_fn: str = "silu", |
| 479 | + temporal_expansions: Tuple[int, ...] = (1, 2, 3), |
| 480 | + spatial_expansions: Tuple[int, ...] = (2, 2, 2), |
| 481 | + latents_mean: Tuple[float, ...] = ( |
| 482 | + -0.06730895953510081, |
| 483 | + -0.038011381506090416, |
| 484 | + -0.07477820912866141, |
| 485 | + -0.05565264470995561, |
| 486 | + 0.012767231469026969, |
| 487 | + -0.04703542746246419, |
| 488 | + 0.043896967884726704, |
| 489 | + -0.09346305707025976, |
| 490 | + -0.09918314763016893, |
| 491 | + -0.008729793427399178, |
| 492 | + -0.011931556316503654, |
| 493 | + -0.0321993391887285, |
| 494 | + ), |
| 495 | + latents_std: Tuple[float, ...] = ( |
| 496 | + 0.9263795028493863, |
| 497 | + 0.9248894543193766, |
| 498 | + 0.9393059390890617, |
| 499 | + 0.959253732819592, |
| 500 | + 0.8244560132752793, |
| 501 | + 0.917259975397747, |
| 502 | + 0.9294154431013696, |
| 503 | + 1.3720942357788521, |
| 504 | + 0.881393668867029, |
| 505 | + 0.9168315692124348, |
| 506 | + 0.9185249279345552, |
| 507 | + 0.9274757570805041, |
| 508 | + ), |
| 509 | + scaling_factor: float = 1.0, |
| 510 | + ): |
| 511 | + super().__init__() |
| 512 | + |
| 513 | + self.decoder = MochiDecoder3D( |
| 514 | + in_channels=latent_channels, |
| 515 | + out_channels=out_channels, |
| 516 | + block_out_channels=block_out_channels, |
| 517 | + layers_per_block=layers_per_block, |
| 518 | + temporal_expansions=temporal_expansions, |
| 519 | + spatial_expansions=spatial_expansions, |
| 520 | + act_fn=act_fn, |
| 521 | + ) |
| 522 | + |
| 523 | + self.use_slicing = False
| 524 | + self.use_tiling = False
| | +
| | + # Assumed defaults (not set anywhere in this commit) for the attributes that
| | + # `enable_tiling` and `tiled_decode` below rely on; adjust them via `enable_tiling`.
| | + self.tile_sample_min_height = 256
| | + self.tile_sample_min_width = 256
| | + self.tile_latent_min_height = int(self.tile_sample_min_height / (2 ** (len(block_out_channels) - 1)))
| | + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(block_out_channels) - 1)))
| | + self.tile_overlap_factor_height = 1 / 6
| | + self.tile_overlap_factor_width = 1 / 5
| | + self.num_latent_frames_batch_size = 2
| | + # No post-quantization convolution in this decoder-only model; the attribute is kept
| | + # because `tiled_decode` checks for it.
| | + self.post_quant_conv = None
| 525 | + |
| 526 | + def _set_gradient_checkpointing(self, module, value=False): |
| 527 | + if isinstance(module, MochiDecoder3D): |
| 528 | + module.gradient_checkpointing = value |
| 529 | + |
| 530 | + def enable_tiling( |
| 531 | + self, |
| 532 | + tile_sample_min_height: Optional[int] = None, |
| 533 | + tile_sample_min_width: Optional[int] = None, |
| 534 | + tile_overlap_factor_height: Optional[float] = None, |
| 535 | + tile_overlap_factor_width: Optional[float] = None, |
| 536 | + ) -> None: |
| 537 | + r""" |
| 538 | + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles and
| 539 | + decode them in several steps. This is useful for saving a large amount of memory and for processing larger
| 540 | + videos.
| 541 | +
|
| 542 | + Args: |
| 543 | + tile_sample_min_height (`int`, *optional*): |
| 544 | + The minimum height required for a sample to be separated into tiles across the height dimension. |
| 545 | + tile_sample_min_width (`int`, *optional*): |
| 546 | + The minimum width required for a sample to be separated into tiles across the width dimension. |
| 547 | + tile_overlap_factor_height (`float`, *optional*):
| 548 | + The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
| 549 | + no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
| 550 | + value might cause more tiles to be processed, leading to a slowdown of the decoding process.
| 551 | + tile_overlap_factor_width (`float`, *optional*):
| 552 | + The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
| 553 | + are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
| 554 | + value might cause more tiles to be processed, leading to a slowdown of the decoding process.
| 555 | + """ |
| 556 | + self.use_tiling = True |
| 557 | + self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height |
| 558 | + self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width |
| 559 | + self.tile_latent_min_height = int( |
| 560 | + self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1)) |
| 561 | + ) |
| 562 | + self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1))) |
| 563 | + self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height |
| 564 | + self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width |
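To make the sample-to-latent mapping in `enable_tiling` concrete, here is a small worked example; the tile sizes are assumptions chosen for illustration, not values shipped in this commit.

```python
# With four entries in `block_out_channels`, the decoder upsamples spatially by 2 ** (4 - 1) = 8
# (matching spatial_expansions = (2, 2, 2)), so a 256-pixel tile in sample space corresponds
# to a 32-element tile in latent space.
block_out_channels = (128, 256, 256, 512)
tile_sample_min_height = 256
tile_sample_min_width = 256

spatial_compression = 2 ** (len(block_out_channels) - 1)                 # 8
tile_latent_min_height = tile_sample_min_height // spatial_compression  # 32
tile_latent_min_width = tile_sample_min_width // spatial_compression    # 32
print(spatial_compression, tile_latent_min_height, tile_latent_min_width)  # 8 32 32
```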
| 565 | + |
| 566 | + def disable_tiling(self) -> None: |
| 567 | + r""" |
| 568 | + Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing |
| 569 | + decoding in one step. |
| 570 | + """ |
| 571 | + self.use_tiling = False |
| 572 | + |
| 573 | + def enable_slicing(self) -> None: |
| 574 | + r""" |
| 575 | + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to |
| 576 | + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. |
| 577 | + """ |
| 578 | + self.use_slicing = True |
| 579 | + |
| 580 | + def disable_slicing(self) -> None: |
| 581 | + r""" |
| 582 | + Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing |
| 583 | + decoding in one step. |
| 584 | + """ |
| 585 | + self.use_slicing = False |
| 586 | + |
| 587 | + @apply_forward_hook |
| 588 | + def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: |
| 589 | + """ |
| 590 | + Decode a batch of images. |
| 591 | +
|
| 592 | + Args: |
| 593 | + z (`torch.Tensor`): Input batch of latent vectors. |
| 594 | + return_dict (`bool`, *optional*, defaults to `True`): |
| 595 | + Whether to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. |
| 596 | +
|
| 597 | + Returns: |
| 598 | + [`~models.vae.DecoderOutput`] or `tuple`: |
| 599 | + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is |
| 600 | + returned. |
| 601 | + """ |
| | + # Assumed dispatch (mirrors other diffusers VAEs): decode large inputs tile by tile when tiling is enabled.
| | + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_width or z.shape[-2] > self.tile_latent_min_height):
| | + return self.tiled_decode(z, return_dict=return_dict)
| | +
| 602 | + if self.use_slicing and z.shape[0] > 1:
| 603 | + decoded_slices = [self.decoder(z_slice) for z_slice in z.split(1)] |
| 604 | + decoded = torch.cat(decoded_slices) |
| 605 | + else: |
| 606 | + decoded = self.decoder(z) |
| 607 | + |
| 608 | + if not return_dict: |
| 609 | + return (decoded,) |
| 610 | + return DecoderOutput(sample=decoded) |
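A hedged usage sketch of the decoding path above; the latent shape and the slicing choice are illustrative assumptions, not values prescribed by this commit.

```python
import torch

# Build the decoder-only VAE with its default config and decode a dummy latent batch.
vae = AutoencoderKLMochi()
vae.enable_slicing()  # decode one sample at a time to reduce peak memory

latents = torch.randn(2, 12, 3, 32, 32)  # (batch, latent channels, frames, height, width), assumed sizes
with torch.no_grad():
    video = vae.decode(latents).sample   # DecoderOutput.sample holds the decoded frames
print(video.shape)
```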
| 611 | + |
| 612 | + def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: |
| 613 | + blend_extent = min(a.shape[3], b.shape[3], blend_extent) |
| 614 | + for y in range(blend_extent): |
| 615 | + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * ( |
| 616 | + y / blend_extent |
| 617 | + ) |
| 618 | + return b |
| 619 | + |
| 620 | + def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor: |
| 621 | + blend_extent = min(a.shape[4], b.shape[4], blend_extent) |
| 622 | + for x in range(blend_extent): |
| 623 | + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * ( |
| 624 | + x / blend_extent |
| 625 | + ) |
| 626 | + return b |
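For intuition, a tiny self-contained sketch of the same linear cross-fade used by `blend_v`/`blend_h`, applied to 1-D toy tensors (sizes are arbitrary):

```python
import torch

# The weight of the previous tile decays from 1 to 0 across `blend_extent` positions while
# the weight of the current tile grows from 0 to 1, hiding the seam between tiles.
a = torch.ones(8)   # stand-in for the trailing edge of the previous tile
b = torch.zeros(8)  # stand-in for the leading edge of the current tile
blend_extent = 4
for x in range(blend_extent):
    b[x] = a[-blend_extent + x] * (1 - x / blend_extent) + b[x] * (x / blend_extent)
print(b)  # tensor([1.0000, 0.7500, 0.5000, 0.2500, 0.0000, 0.0000, 0.0000, 0.0000])
```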
| 627 | + |
| 628 | + def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]: |
| 629 | + r""" |
| 630 | + Decode a batch of images using a tiled decoder. |
| 631 | +
|
| 632 | + Args: |
| 633 | + z (`torch.Tensor`): Input batch of latent vectors. |
| 634 | + return_dict (`bool`, *optional*, defaults to `True`): |
| 635 | + Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple. |
| 636 | +
|
| 637 | + Returns: |
| 638 | + [`~models.vae.DecoderOutput`] or `tuple`: |
| 639 | + If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is |
| 640 | + returned. |
| 641 | + """ |
| 642 | + |
| 643 | + batch_size, num_channels, num_frames, height, width = z.shape |
| 644 | + |
| 645 | + overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height)) |
| 646 | + overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width)) |
| 647 | + blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height) |
| 648 | + blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width) |
| 649 | + row_limit_height = self.tile_sample_min_height - blend_extent_height |
| 650 | + row_limit_width = self.tile_sample_min_width - blend_extent_width |
| 651 | + frame_batch_size = self.num_latent_frames_batch_size |
| 652 | + |
| 653 | + # Split z into overlapping tiles and decode them separately. |
| 654 | + # The tiles have an overlap to avoid seams between tiles. |
| 655 | + rows = [] |
| 656 | + for i in range(0, height, overlap_height): |
| 657 | + row = [] |
| 658 | + for j in range(0, width, overlap_width): |
| 659 | + num_batches = max(num_frames // frame_batch_size, 1) |
| 660 | + # MochiDecoder3D above returns a single tensor and keeps no conv cache, so each frame batch is decoded independently.
| 661 | + time = [] |
| 662 | + |
| 663 | + for k in range(num_batches): |
| 664 | + remaining_frames = num_frames % frame_batch_size |
| 665 | + start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames) |
| 666 | + end_frame = frame_batch_size * (k + 1) + remaining_frames |
| 667 | + tile = z[ |
| 668 | + :, |
| 669 | + :, |
| 670 | + start_frame:end_frame, |
| 671 | + i : i + self.tile_latent_min_height, |
| 672 | + j : j + self.tile_latent_min_width, |
| 673 | + ] |
| 674 | + if self.post_quant_conv is not None: |
| 675 | + tile = self.post_quant_conv(tile) |
| 676 | + tile = self.decoder(tile)
| 677 | + time.append(tile) |
| 678 | + |
| 679 | + row.append(torch.cat(time, dim=2)) |
| 680 | + rows.append(row) |
| 681 | + |
| 682 | + result_rows = [] |
| 683 | + for i, row in enumerate(rows): |
| 684 | + result_row = [] |
| 685 | + for j, tile in enumerate(row): |
| 686 | + # blend the above tile and the left tile |
| 687 | + # to the current tile and add the current tile to the result row |
| 688 | + if i > 0: |
| 689 | + tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height) |
| 690 | + if j > 0: |
| 691 | + tile = self.blend_h(row[j - 1], tile, blend_extent_width) |
| 692 | + result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width]) |
| 693 | + result_rows.append(torch.cat(result_row, dim=4)) |
| 694 | + |
| 695 | + dec = torch.cat(result_rows, dim=3) |
| 696 | + |
| 697 | + if not return_dict: |
| 698 | + return (dec,) |
| 699 | + |
| 700 | + return DecoderOutput(sample=dec) |
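Finally, a hedged end-to-end sketch of exercising `tiled_decode`; every size and overlap value below is an assumption chosen for illustration, not a default shipped in this commit.

```python
import torch

vae = AutoencoderKLMochi()
# Configure tiling explicitly so the attributes read by `tiled_decode` are well defined.
vae.enable_tiling(
    tile_sample_min_height=256,
    tile_sample_min_width=256,
    tile_overlap_factor_height=1 / 6,
    tile_overlap_factor_width=1 / 5,
)

# A 64x64 latent grid is split into overlapping 32x32 latent tiles, each tile is decoded
# on its own, and the decoded tiles are blended back together to avoid visible seams.
latents = torch.randn(1, 12, 3, 64, 64)  # assumed (batch, channels, frames, height, width)
with torch.no_grad():
    video = vae.tiled_decode(latents).sample
print(video.shape)
```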