14 | 14 | from typing import Any, Dict, Optional, Tuple, Union |
15 | 15 |
16 | 16 | import torch |
| 17 | +import torch.nn.functional as F |
| 18 | +from torch import nn |
17 | 19 |
18 | 20 | from diffusers.configuration_utils import register_to_config |
19 | 21 | from diffusers.models.controlnet import ( |
26 | 28 | logger = logging.get_logger(__name__) # pylint: disable=invalid-name |
27 | 29 |
28 | 30 |
| 31 | +class AnyTextControlNetConditioningEmbedding(nn.Module): |
| 32 | + """ |
| 33 | + Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN |
| 34 | + [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized |
| 35 | + training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the |
| 36 | + convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides |
| 37 | + (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full |
| 38 | + model) to encode image-space conditions ... into feature maps ..." |
| 39 | + """ |
| 40 | + |
| 41 | + def __init__( |
| 42 | + self, |
| 43 | + conditioning_embedding_channels: int, |
| 44 | + conditioning_channels: int = 3, |
| 45 | + block_out_channels: Tuple[int, ...] = (16, 32, 96, 256), |
| 46 | + ): |
| 47 | + super().__init__() |
| 48 | + |
| 49 | + self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1) |
| 50 | + |
| 51 | + self.blocks = nn.ModuleList([]) |
| 52 | + |
| 53 | + for i in range(len(block_out_channels) - 1): |
| 54 | + channel_in = block_out_channels[i] |
| 55 | + channel_out = block_out_channels[i + 1] |
| 56 | + self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1)) |
| 57 | + self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2)) |
| 58 | + |
| 59 | + self.conv_out = zero_module( |
| 60 | + nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1) |
| 61 | + ) |
| 62 | + |
| 63 | + def forward(self, conditioning): |
| 64 | + embedding = self.conv_in(conditioning) |
| 65 | + embedding = F.silu(embedding) |
| 66 | + |
| 67 | + for block in self.blocks: |
| 68 | + embedding = block(embedding) |
| 69 | + embedding = F.silu(embedding) |
| 70 | + |
| 71 | + embedding = self.conv_out(embedding) |
| 72 | + |
| 73 | + return embedding |
| 74 | + |
| 75 | + |
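For reference, a minimal sketch of what the new conditioning embedding does, under illustrative assumptions (320 for conditioning_embedding_channels matches the first UNet block width of an SD 1.5-style config, and the class is assumed importable from this module): the three stride-2 convolutions take a 512 x 512 image-space condition down to the 64 x 64 latent resolution, and the zero-initialized conv_out means the embedding contributes nothing at initialization.

    import torch

    # Illustrative value: 320 is the first UNet block width in an SD 1.5-style config.
    embedding = AnyTextControlNetConditioningEmbedding(
        conditioning_embedding_channels=320,
        conditioning_channels=3,
        block_out_channels=(16, 32, 96, 256),
    )

    cond = torch.randn(1, 3, 512, 512)  # image-space condition
    feat = embedding(cond)              # latent-space feature map

    print(feat.shape)        # torch.Size([1, 320, 64, 64]): 512 -> 256 -> 128 -> 64 via the three stride-2 convs
    print(feat.abs().max())  # all zeros at initialization, because conv_out is zero-initialized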
29 | 76 | class AnyTextControlNetModel(ControlNetModel): |
30 | 77 | """ |
31 | 78 | An AnyTextControlNetModel model. |
@@ -172,16 +219,21 @@ def __init__( |
172 | 219 | global_pool_conditions, |
173 | 220 | addition_embed_type_num_heads, |
174 | 221 | ) |
175 | | - self.controlnet_cond_embedding = ( |
176 | | - None # TODO: Instead of this, design a custom `ControlNetConditioningEmbedding` |
| 222 | + |
| 223 | + # control net conditioning embedding |
| 224 | + # TODO: what happens to the memory occupied by ControlNetModel's original self.controlnet_cond_embedding? |
| 225 | + self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding( |
| 226 | + conditioning_embedding_channels=block_out_channels[0], |
| 227 | + block_out_channels=conditioning_embedding_out_channels, |
| 228 | + conditioning_channels=conditioning_channels, |
177 | 229 | ) |
178 | 230 |
179 | 231 | def forward( |
180 | 232 | self, |
181 | 233 | sample: torch.Tensor, |
182 | 234 | timestep: Union[torch.Tensor, float, int], |
183 | 235 | encoder_hidden_states: torch.Tensor, |
184 | | - guided_hint: torch.Tensor, |
| 236 | + controlnet_cond: torch.Tensor, |
185 | 237 | conditioning_scale: float = 1.0, |
186 | 238 | class_labels: Optional[torch.Tensor] = None, |
187 | 239 | timestep_cond: Optional[torch.Tensor] = None, |
@@ -310,8 +362,8 @@ def forward( |
310 | 362 | # 2. pre-process |
311 | 363 | sample = self.conv_in(sample) |
312 | 364 |
313 | | - # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) |
314 | | - sample = sample + guided_hint |
| 365 | + controlnet_cond = self.controlnet_cond_embedding(controlnet_cond) |
| 366 | + sample = sample + controlnet_cond |
315 | 367 |
316 | 368 | # 3. down |
317 | 369 | down_block_res_samples = (sample,) |
@@ -375,3 +427,10 @@ def forward( |
375 | 427 | return ControlNetOutput( |
376 | 428 | down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample |
377 | 429 | ) |
| 430 | + |
| 431 | + |
| 432 | +# Copied from diffusers.models.controlnet.zero_module |
| 433 | +def zero_module(module): |
| 434 | + for p in module.parameters(): |
| 435 | + nn.init.zeros_(p) |
| 436 | + return module |
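A hedged sketch of how the updated forward signature might be exercised after this change: the caller now passes the raw image-space condition as controlnet_cond (instead of a pre-computed guided_hint), and the model embeds it internally before adding it to the latent sample. The snippet assumes controlnet is an AnyTextControlNetModel instance built or loaded elsewhere with an SD 1.5-style config (cross_attention_dim=768, first block width 320); the shapes are illustrative.

    import torch

    # Assumption: `controlnet` is an AnyTextControlNetModel with an SD 1.5-style config,
    # constructed or loaded elsewhere.
    sample = torch.randn(2, 4, 64, 64)               # noisy latents
    timestep = torch.tensor(999)                     # diffusion timestep
    encoder_hidden_states = torch.randn(2, 77, 768)  # text-encoder hidden states
    controlnet_cond = torch.randn(2, 3, 512, 512)    # raw image-space condition (formerly `guided_hint`)

    output = controlnet(
        sample,
        timestep,
        encoder_hidden_states=encoder_hidden_states,
        controlnet_cond=controlnet_cond,
        conditioning_scale=1.0,
    )

    # ControlNetOutput carries the residuals the UNet consumes.
    down_res = output.down_block_res_samples  # tuple of per-down-block residuals
    mid_res = output.mid_block_res_sample     # residual added at the mid block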