
Commit 2f42e40

Add AnyTextControlNetConditioningEmbedding template
1 parent 48e88eb commit 2f42e40

File tree

1 file changed: +64 -5 lines

examples/research_projects/anytext/text_controlnet.py

Lines changed: 64 additions & 5 deletions
@@ -14,6 +14,8 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
+from torch import nn
 
 from diffusers.configuration_utils import register_to_config
 from diffusers.models.controlnet import (
@@ -26,6 +28,51 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class AnyTextControlNetConditioningEmbedding(nn.Module):
+    """
+    Quoting from https://arxiv.org/abs/2302.05543: "Stable Diffusion uses a pre-processing method similar to VQ-GAN
+    [11] to convert the entire dataset of 512 × 512 images into smaller 64 × 64 “latent images” for stabilized
+    training. This requires ControlNets to convert image-based conditions to 64 × 64 feature space to match the
+    convolution size. We use a tiny network E(·) of four convolution layers with 4 × 4 kernels and 2 × 2 strides
+    (activated by ReLU, channels are 16, 32, 64, 128, initialized with Gaussian weights, trained jointly with the full
+    model) to encode image-space conditions ... into feature maps ..."
+    """
+
+    def __init__(
+        self,
+        conditioning_embedding_channels: int,
+        conditioning_channels: int = 3,
+        block_out_channels: Tuple[int, ...] = (16, 32, 96, 256),
+    ):
+        super().__init__()
+
+        self.conv_in = nn.Conv2d(conditioning_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+        self.blocks = nn.ModuleList([])
+
+        for i in range(len(block_out_channels) - 1):
+            channel_in = block_out_channels[i]
+            channel_out = block_out_channels[i + 1]
+            self.blocks.append(nn.Conv2d(channel_in, channel_in, kernel_size=3, padding=1))
+            self.blocks.append(nn.Conv2d(channel_in, channel_out, kernel_size=3, padding=1, stride=2))
+
+        self.conv_out = zero_module(
+            nn.Conv2d(block_out_channels[-1], conditioning_embedding_channels, kernel_size=3, padding=1)
+        )
+
+    def forward(self, conditioning):
+        embedding = self.conv_in(conditioning)
+        embedding = F.silu(embedding)
+
+        for block in self.blocks:
+            embedding = block(embedding)
+            embedding = F.silu(embedding)
+
+        embedding = self.conv_out(embedding)
+
+        return embedding
+
+
 class AnyTextControlNetModel(ControlNetModel):
     """
     A AnyTextControlNetModel model.
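
For orientation, a hypothetical shape check (not part of the commit; it assumes the class above is in scope and picks conditioning_embedding_channels=320 to match Stable Diffusion's first UNet block width): the three stride-2 convolutions downsample by 8×, so a 512 × 512 condition image lands in the 64 × 64 latent feature space the docstring describes.

import torch

embedding = AnyTextControlNetConditioningEmbedding(conditioning_embedding_channels=320)
cond = torch.randn(1, 3, 512, 512)  # one RGB image-space condition
features = embedding(cond)
print(features.shape)  # torch.Size([1, 320, 64, 64])
# conv_out is zero-initialized, so before any training `features` is all zeros.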
@@ -172,16 +219,21 @@ def __init__(
             global_pool_conditions,
             addition_embed_type_num_heads,
         )
-        self.controlnet_cond_embedding = (
-            None  # TODO: Instead of this, design a custom `ControlNetConditioningEmbedding`
+
+        # control net conditioning embedding
+        # TODO: what happens to the memory occupied by ControlNetModel's original self.controlnet_cond_embedding once it is replaced here?
+        self.controlnet_cond_embedding = AnyTextControlNetConditioningEmbedding(
+            conditioning_embedding_channels=block_out_channels[0],
+            block_out_channels=conditioning_embedding_out_channels,
+            conditioning_channels=conditioning_channels,
         )
 
     def forward(
         self,
         sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
-        guided_hint: torch.Tensor,
+        controlnet_cond: torch.Tensor,
         conditioning_scale: float = 1.0,
         class_labels: Optional[torch.Tensor] = None,
         timestep_cond: Optional[torch.Tensor] = None,
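
For context, a hypothetical instantiation of the patched model (argument names follow diffusers' ControlNetModel config; the SD 1.5-style values are assumptions, not taken from this commit):

model = AnyTextControlNetModel(
    block_out_channels=(320, 640, 1280, 1280),  # SD 1.5 UNet widths
    conditioning_embedding_out_channels=(16, 32, 96, 256),
    conditioning_channels=3,
)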
@@ -310,8 +362,8 @@ def forward(
         # 2. pre-process
         sample = self.conv_in(sample)
 
-        # controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
-        sample = sample + guided_hint
+        controlnet_cond = self.controlnet_cond_embedding(controlnet_cond)
+        sample = sample + controlnet_cond
 
         # 3. down
         down_block_res_samples = (sample,)
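
With the embedding wired in, forward now takes a raw image-space condition instead of a pre-computed guided_hint. A hypothetical call, assuming SD 1.5 dimensions (4 latent channels, 77 × 768 text-encoder states) and the model sketched above; the embedding's 8× downsampling makes the 512 × 512 condition line up with the 64 × 64 sample after conv_in:

out = model(
    sample=torch.randn(1, 4, 64, 64),  # noisy latent
    timestep=10,
    encoder_hidden_states=torch.randn(1, 77, 768),  # text-encoder states
    controlnet_cond=torch.randn(1, 3, 512, 512),  # full-resolution condition image
)
down_res, mid_res = out.down_block_res_samples, out.mid_block_res_sample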
@@ -375,3 +427,10 @@ def forward(
         return ControlNetOutput(
             down_block_res_samples=down_block_res_samples, mid_block_res_sample=mid_block_res_sample
         )
+
+
+# Copied from diffusers.models.controlnet.zero_module
+def zero_module(module):
+    for p in module.parameters():
+        nn.init.zeros_(p)
+    return module
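
Because conv_out is wrapped in zero_module, the embedding contributes exactly zero to sample at initialization, so training starts from the base model's unmodified behavior (the ControlNet zero-convolution trick). A minimal hypothetical check, assuming torch and nn are imported as in the file above:

layer = zero_module(nn.Conv2d(256, 320, kernel_size=3, padding=1))
x = torch.randn(1, 256, 64, 64)
assert torch.all(layer(x) == 0)  # zero weights and zero bias give zero output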
