# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from ..modeling_utils import ModelMixin


@dataclass
class VQEncoderOutput(BaseOutput):
    """
    Output of VQModel encoding method.

    Args:
        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The encoded output sample from the last layer of the model.
    """

    latents: torch.Tensor


class VQModel(ModelMixin, ConfigMixin):
    r"""
    A VQ-VAE model for decoding latent representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (`int`, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (`int`, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
        scaling_factor (`float`, *optional*, defaults to `0.18215`):
            The component-wise standard deviation of the trained latent space, computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula:
            `z = 1 / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution
            Image Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        norm_type (`str`, *optional*, defaults to `"group"`):
            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
        mid_block_add_attention (`bool`, *optional*, defaults to `True`):
            Whether to add an attention block in the mid block of the encoder and decoder.
        lookup_from_codebook (`bool`, *optional*, defaults to `False`):
            Whether `decode` should treat its input as codebook indices and look the corresponding embeddings up
            (used together with `force_not_quantize=True`).
        force_upcast (`bool`, *optional*, defaults to `False`):
            Whether to force the model to run in `float32`; kept for config parity with the KL autoencoders.
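
    Example:
        A minimal encode/decode round trip (an illustrative sketch: it builds a randomly initialized model with the
        default single-block config rather than loading a trained checkpoint):

        ```py
        import torch

        from diffusers import VQModel

        model = VQModel()  # random weights, default config
        x = torch.randn(1, 3, 32, 32)
        latents = model.encode(x).latents  # continuous latents; quantization happens inside `decode`
        reconstruction = model.decode(latents).sample
        ```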
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 3,
        sample_size: int = 32,
        num_vq_embeddings: int = 256,
        norm_num_groups: int = 32,
        vq_embed_dim: Optional[int] = None,
        scaling_factor: float = 0.18215,
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention: bool = True,
        lookup_from_codebook: bool = False,
        force_upcast: bool = False,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=False,
            mid_block_add_attention=mid_block_add_attention,
        )

        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels

        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
        # beta weights the commitment term of the VQ objective
        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
            mid_block_add_attention=mid_block_add_attention,
        )

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
        # encode to continuous latents; quantization is deferred to `decode`
        h = self.encoder(x)
        h = self.quant_conv(h)

        if not return_dict:
            return (h,)

        return VQEncoderOutput(latents=h)

    @apply_forward_hook
    def decode(
        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
    ) -> Union[DecoderOutput, torch.Tensor]:
        # also go through quantization layer
        if not force_not_quantize:
            quant, commit_loss, _ = self.quantize(h)
        elif self.config.lookup_from_codebook:
            # `h` holds codebook indices; fetch the embeddings and reshape them to `shape`
            quant = self.quantize.get_codebook_entry(h, shape)
            commit_loss = torch.zeros((h.shape[0],), device=h.device, dtype=h.dtype)
        else:
            quant = h
            commit_loss = torch.zeros((h.shape[0],), device=h.device, dtype=h.dtype)
        quant2 = self.post_quant_conv(quant)
        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)

        if not return_dict:
            return dec, commit_loss

        return DecoderOutput(sample=dec, commit_loss=commit_loss)
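
    # A hedged usage note: when the config sets `lookup_from_codebook=True`, `decode` can be fed integer codebook
    # indices together with `force_not_quantize=True` and a `(batch, height, width, channel)` latent `shape`, e.g.
    #
    #     indices = torch.randint(0, model.config.num_vq_embeddings, (1, 8 * 8))
    #     image = model.decode(indices, force_not_quantize=True, shape=(1, 8, 8, model.config.vq_embed_dim)).sample
    #
    # The `8 x 8` grid and `vq_embed_dim` values above are illustrative assumptions; the layout must match what
    # `VectorQuantizer.get_codebook_entry` expects for the checkpoint in use.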

    def forward(
        self, sample: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
        r"""
        The [`VQModel`] forward method.

        Args:
            sample (`torch.Tensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoders.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoders.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoders.vae.DecoderOutput`] is returned, otherwise a plain
                `tuple` of `(sample, commit_loss)` is returned.
        """

        h = self.encode(sample).latents
        dec = self.decode(h)

        if not return_dict:
            return dec.sample, dec.commit_loss
        return dec
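

# A minimal training-style sketch (an assumption for illustration, not part of this module's API): the commitment
# loss returned by `forward` is typically added to a reconstruction term, e.g.
#
#     import torch.nn.functional as F
#
#     out = model(pixel_values)  # DecoderOutput(sample=..., commit_loss=...)
#     loss = F.mse_loss(out.sample, pixel_values) + out.commit_loss.mean()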