
Commit 5edd0b3

move vqmodel to models.autoencoders. (#8292)
1 parent 3a28e36 commit 5edd0b3
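The public API is unaffected by this refactor: `VQModel` is still exported from the package root, only its internal module path changes. A minimal sketch of the import paths touched by this commit (the deep internal path is shown only for illustration; importing from `diffusers` directly remains the supported route):

# Public import: unchanged by this commit.
from diffusers import VQModel

# Internal module path after this commit (previously diffusers.models.vq_model).
from diffusers.models.autoencoders.vq_model import VQModel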

5 files changed: +194 -168 lines


docs/source/en/api/models/vq.md

Lines changed: 1 addition & 1 deletion
@@ -24,4 +24,4 @@ The abstract from the paper is:
 
 ## VQEncoderOutput
 
-[[autodoc]] models.vq_model.VQEncoderOutput
+[[autodoc]] models.autoencoders.vq_model.VQEncoderOutput

src/diffusers/models/__init__.py

Lines changed: 2 additions & 2 deletions
@@ -31,6 +31,7 @@
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
+    _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["controlnet"] = ["ControlNetModel"]
     _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
     _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
@@ -50,7 +51,6 @@
     _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
     _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
-    _import_structure["vq_model"] = ["VQModel"]
 
 if is_flax_available():
     _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
@@ -67,6 +67,7 @@
             AutoencoderKLTemporalDecoder,
             AutoencoderTiny,
             ConsistencyDecoderVAE,
+            VQModel,
         )
         from .controlnet import ControlNetModel
         from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
@@ -92,7 +93,6 @@
             UNetSpatioTemporalConditionModel,
             UVit2DModel,
         )
-        from .vq_model import VQModel
 
     if is_flax_available():
         from .controlnet_flax import FlaxControlNetModel
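For context, `_import_structure` maps submodule paths under `diffusers.models` to the names they export, and a lazy-module wrapper resolves attributes on first access; the key rename above is what redirects that lookup. A rough, illustrative sketch of the resolution step (not the actual `_LazyModule` implementation; the helper name `_resolve` is made up here):

import importlib

_import_structure = {
    "autoencoders.vq_model": ["VQModel"],  # key added by this commit
    # "vq_model": ["VQModel"],             # key removed by this commit
}

def _resolve(name, package="diffusers.models"):
    # Find the submodule that advertises `name`, import it on demand, return the attribute.
    for module_path, exported in _import_structure.items():
        if name in exported:
            module = importlib.import_module(f"{package}.{module_path}")
            return getattr(module, name)
    raise AttributeError(name)

# VQModel = _resolve("VQModel")  # now goes through diffusers.models.autoencoders.vq_model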

src/diffusers/models/autoencoders/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_tiny import AutoencoderTiny
 from .consistency_decoder_vae import ConsistencyDecoderVAE
+from .vq_model import VQModel
src/diffusers/models/autoencoders/vq_model.py

Lines changed: 182 additions & 0 deletions
@@ -0,0 +1,182 @@
# Copyright 2024 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ...utils.accelerate_utils import apply_forward_hook
from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
from ..modeling_utils import ModelMixin


@dataclass
class VQEncoderOutput(BaseOutput):
    """
    Output of VQModel encoding method.

    Args:
        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
            The encoded output sample from the last layer of the model.
    """

    latents: torch.Tensor


class VQModel(ModelMixin, ConfigMixin):
    r"""
    A VQ-VAE model for decoding latent representations.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
            Tuple of downsample block types.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
            Tuple of upsample block types.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
            Tuple of block output channels.
        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
        scaling_factor (`float`, *optional*, defaults to `0.18215`):
            The component-wise standard deviation of the trained latent space computed using the first batch of the
            training set. This is used to scale the latent space to have unit variance when training the diffusion
            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
        norm_type (`str`, *optional*, defaults to `"group"`):
            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
    """

    @register_to_config
    def __init__(
        self,
        in_channels: int = 3,
        out_channels: int = 3,
        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
        block_out_channels: Tuple[int, ...] = (64,),
        layers_per_block: int = 1,
        act_fn: str = "silu",
        latent_channels: int = 3,
        sample_size: int = 32,
        num_vq_embeddings: int = 256,
        norm_num_groups: int = 32,
        vq_embed_dim: Optional[int] = None,
        scaling_factor: float = 0.18215,
        norm_type: str = "group",  # group, spatial
        mid_block_add_attention=True,
        lookup_from_codebook=False,
        force_upcast=False,
    ):
        super().__init__()

        # pass init params to Encoder
        self.encoder = Encoder(
            in_channels=in_channels,
            out_channels=latent_channels,
            down_block_types=down_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            double_z=False,
            mid_block_add_attention=mid_block_add_attention,
        )

        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels

        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)

        # pass init params to Decoder
        self.decoder = Decoder(
            in_channels=latent_channels,
            out_channels=out_channels,
            up_block_types=up_block_types,
            block_out_channels=block_out_channels,
            layers_per_block=layers_per_block,
            act_fn=act_fn,
            norm_num_groups=norm_num_groups,
            norm_type=norm_type,
            mid_block_add_attention=mid_block_add_attention,
        )

    @apply_forward_hook
    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
        h = self.encoder(x)
        h = self.quant_conv(h)

        if not return_dict:
            return (h,)

        return VQEncoderOutput(latents=h)

    @apply_forward_hook
    def decode(
        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
    ) -> Union[DecoderOutput, torch.Tensor]:
        # also go through quantization layer
        if not force_not_quantize:
            quant, commit_loss, _ = self.quantize(h)
        elif self.config.lookup_from_codebook:
            quant = self.quantize.get_codebook_entry(h, shape)
            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
        else:
            quant = h
            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
        quant2 = self.post_quant_conv(quant)
        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)

        if not return_dict:
            return dec, commit_loss

        return DecoderOutput(sample=dec, commit_loss=commit_loss)

    def forward(
        self, sample: torch.Tensor, return_dict: bool = True
    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
        r"""
        The [`VQModel`] forward method.

        Args:
            sample (`torch.Tensor`): Input sample.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.autoencoders.vae.DecoderOutput`] instead of a plain tuple.

        Returns:
            [`~models.autoencoders.vae.DecoderOutput`] or `tuple`:
                If return_dict is True, a [`~models.autoencoders.vae.DecoderOutput`] is returned, otherwise a plain
                `tuple` is returned.
        """

        h = self.encode(sample).latents
        dec = self.decode(h)

        if not return_dict:
            return dec.sample, dec.commit_loss
        return dec
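To round this out, a minimal, hypothetical usage sketch of the class above (default constructor arguments, a random tensor standing in for a real image batch):

import torch
from diffusers import VQModel

model = VQModel()  # defaults: 3 latent channels, 256 codebook entries, 32x32 sample size
model.eval()

x = torch.randn(1, 3, 32, 32)  # dummy image batch

with torch.no_grad():
    latents = model.encode(x).latents  # pre-quantization latents
    out = model.decode(latents)        # quantize, then decode
    reconstruction, commit_loss = out.sample, out.commit_loss

print(reconstruction.shape)  # expected: torch.Size([1, 3, 32, 32]) with the default single-block config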
