
Commit 816daad

support openai/adm with minimal code change
1 parent 8581d9b commit 816daad

4 files changed, +248 -47 lines


src/diffusers/models/attention_processor.py

Lines changed: 56 additions & 0 deletions
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import math
 from importlib import import_module
 from typing import Callable, Optional, Union

@@ -707,6 +708,61 @@ def fuse_projections(self, fuse=True):
         self.fused_projections = fuse


+class QKVAttentionADM(nn.Module):
+    def __init__(self, n_heads):
+        super().__init__()
+        self.n_heads = n_heads
+
+    def forward(self, qkv):
+        bs, width, length = qkv.shape
+        assert width % (3 * self.n_heads) == 0
+        ch = width // (3 * self.n_heads)
+        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+        scale = 1 / math.sqrt(math.sqrt(ch))
+        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)  # More stable with f16 than dividing afterwards
+        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+        a = torch.einsum("bts,bcs->bct", weight, v)
+        return a.reshape(bs, -1, length)
+
+
+class AttentionADM(nn.Module):
+    def __init__(
+        self,
+        channels,
+        num_heads=1,
+        num_head_channels=-1,
+    ):
+        super().__init__()
+        self.channels = channels
+        if num_head_channels == -1:
+            self.num_heads = num_heads
+        else:
+            assert (
+                channels % num_head_channels == 0
+            ), f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+            self.num_heads = channels // num_head_channels
+        self.norm = nn.GroupNorm(32, channels)
+        self.qkv = nn.Conv1d(channels, channels * 3, 1, 1)
+        self.attention = QKVAttentionADM(self.num_heads)
+        self.proj_out = nn.Conv1d(channels, channels, 1, 1)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        **cross_attention_kwargs,
+    ):
+        # ignore temb and kwargs for now
+        b, c, *spatial = hidden_states.shape
+        hidden_states = hidden_states.reshape(b, c, -1)
+        qkv = self.qkv(self.norm(hidden_states))
+        h = self.attention(qkv)
+        h = self.proj_out(h)
+        return (hidden_states + h).reshape(b, c, *spatial)
+
+
 class AttnProcessor:
     r"""
     Default processor for performing attention-related computations.

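For reference, a minimal shape check of the module added above. This is a sketch that assumes a build containing this commit; the channel and head values are illustrative, not from the diff.

# Sketch: round-trip an NCHW feature map through the new ADM attention block.
# AttentionADM flattens (B, C, H, W) to (B, C, H*W), applies GroupNorm and a 1x1
# Conv1d QKV projection, runs QKVAttentionADM, projects out, and adds the residual
# before reshaping back to the original spatial layout.
import torch
from diffusers.models.attention_processor import AttentionADM

attn = AttentionADM(channels=64, num_head_channels=32)  # 64 // 32 = 2 heads
x = torch.randn(1, 64, 8, 8)
out = attn(x)
assert out.shape == x.shape  # residual connection preserves the input shape
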
src/diffusers/models/embeddings.py

Lines changed: 32 additions & 0 deletions
@@ -254,6 +254,38 @@ def forward(self, timesteps):
         return t_emb


+def timestep_embedding_adm(timesteps, dim, max_period=10000):
+    """
+    ADM order embedding from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103
+    """
+    half = dim // 2
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
+        device=timesteps.device
+    )
+    args = timesteps[:, None].float() * freqs[None]
+    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+    if dim % 2:
+        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+    return embedding
+
+
+class TimestepsADM(nn.Module):
+    """
+    ADM order embedding from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103
+    """
+
+    def __init__(self, num_channels: int):
+        super().__init__()
+        self.num_channels = num_channels
+
+    def forward(self, timesteps):
+        t_emb = timestep_embedding_adm(
+            timesteps,
+            self.num_channels,
+        )
+        return t_emb
+
+
 class GaussianFourierProjection(nn.Module):
     """Gaussian Fourier embeddings for noise levels."""

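For intuition, a small sketch of the new timestep projection (again assuming a build containing this commit). The ADM embedding is cos-first, which is the ordering detail that motivates a separate path so that converted guided-diffusion checkpoints keep their weights aligned.

# Sketch: the ADM embedding concatenates [cos, sin] with frequencies
# exp(-log(10000) * i / half) for i in range(half).
import torch
from diffusers.models.embeddings import TimestepsADM

time_proj = TimestepsADM(num_channels=128)
t_emb = time_proj(torch.tensor([0, 10, 500]))
print(t_emb.shape)  # torch.Size([3, 128])
# For comparison, get_timestep_embedding(t, 128, flip_sin_to_cos=True,
# downscale_freq_shift=0) should give the same cos-first layout and frequency scaling.
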
src/diffusers/models/unets/unet_2d.py

Lines changed: 39 additions & 19 deletions
@@ -19,9 +19,9 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import BaseOutput
-from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
+from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps, TimestepsADM
 from ..modeling_utils import ModelMixin
-from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
+from .unet_2d_blocks import UNetMidBlock2D, UNetMidBlock2DADM, get_down_block, get_up_block


 @dataclass
@@ -58,7 +58,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         down_block_types (`Tuple[str]`, *optional*, defaults to `("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D")`):
             Tuple of downsample block types.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2D"`):
-            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UnCLIPUNetMidBlock2D`.
+            Block type for middle of UNet, it can be either `UNetMidBlock2D` or `UNetMidBlock2DADM`.
         up_block_types (`Tuple[str]`, *optional*, defaults to `("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D")`):
             Tuple of upsample block types.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(224, 448, 672, 896)`):
@@ -72,6 +72,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             The upsample type for upsampling layers. Choose between "conv" and "resnet"
         dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        attention_type (`str`, *optional*, defaults to `"default"`): The attention type. Choose between `"default"` and `"adm"`.
         attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
         norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
         attn_norm_num_groups (`int`, *optional*, defaults to `None`):
@@ -100,6 +101,7 @@ def __init__(
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
+        mid_block_type: str = "UNetMidBlock2D",
         up_block_types: Tuple[str] = ("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
         block_out_channels: Tuple[int] = (224, 448, 672, 896),
         layers_per_block: int = 2,
@@ -109,6 +111,7 @@ def __init__(
         upsample_type: str = "conv",
         dropout: float = 0.0,
         act_fn: str = "silu",
+        attention_type: str = "default",
         attention_head_dim: Optional[int] = 8,
         norm_num_groups: int = 32,
         attn_norm_num_groups: Optional[int] = None,
@@ -148,7 +151,9 @@ def __init__(
         elif time_embedding_type == "learned":
             self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
             timestep_input_dim = block_out_channels[0]
-
+        elif time_embedding_type == "adm":
+            self.time_proj = TimestepsADM(block_out_channels[0])
+            timestep_input_dim = block_out_channels[0]
         self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

         # class embedding
@@ -182,6 +187,7 @@ def __init__(
                 resnet_eps=norm_eps,
                 resnet_act_fn=act_fn,
                 resnet_groups=norm_num_groups,
+                attention_type=attention_type,
                 attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                 downsample_padding=downsample_padding,
                 resnet_time_scale_shift=resnet_time_scale_shift,
@@ -191,20 +197,34 @@ def __init__(
             self.down_blocks.append(down_block)

         # mid
-        self.mid_block = UNetMidBlock2D(
-            in_channels=block_out_channels[-1],
-            temb_channels=time_embed_dim,
-            dropout=dropout,
-            resnet_eps=norm_eps,
-            resnet_act_fn=act_fn,
-            output_scale_factor=mid_block_scale_factor,
-            resnet_time_scale_shift=resnet_time_scale_shift,
-            attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
-            resnet_groups=norm_num_groups,
-            attn_groups=attn_norm_num_groups,
-            add_attention=add_attention,
-        )
-
+        if mid_block_type == "UNetMidBlock2D":
+            self.mid_block = UNetMidBlock2D(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+                resnet_groups=norm_num_groups,
+                attn_groups=attn_norm_num_groups,
+                add_attention=add_attention,
+            )
+        elif mid_block_type == "UNetMidBlock2DADM":
+            self.mid_block = UNetMidBlock2DADM(
+                in_channels=block_out_channels[-1],
+                temb_channels=time_embed_dim,
+                dropout=dropout,
+                resnet_eps=norm_eps,
+                resnet_act_fn=act_fn,
+                output_scale_factor=mid_block_scale_factor,
+                resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_head_dim=attention_head_dim if attention_head_dim is not None else block_out_channels[-1],
+                resnet_groups=norm_num_groups,
+            )
+        else:
+            raise ValueError
         # up
         reversed_block_out_channels = list(reversed(block_out_channels))
         output_channel = reversed_block_out_channels[0]
@@ -214,7 +234,6 @@ def __init__(
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1
-
            up_block = get_up_block(
                up_block_type,
                num_layers=layers_per_block + 1,
@@ -226,6 +245,7 @@ def __init__(
                resnet_eps=norm_eps,
                resnet_act_fn=act_fn,
                resnet_groups=norm_num_groups,
+                attention_type=attention_type,
                attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
                resnet_time_scale_shift=resnet_time_scale_shift,
                upsample_type=upsample_type,

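Taken together, the new options are selected purely through the model config. A usage sketch follows; it assumes a build containing this commit, the block sizes and head dim are illustrative, and the attention_type plumbing into the down/up blocks lives in the fourth changed file, which is not shown in this view.

# Sketch: instantiate UNet2DModel with the ADM-style timestep embedding,
# mid block, and attention, then run a dummy forward pass.
import torch
from diffusers import UNet2DModel

model = UNet2DModel(
    sample_size=64,
    in_channels=3,
    out_channels=3,
    time_embedding_type="adm",           # TimestepsADM projection
    mid_block_type="UNetMidBlock2DADM",  # ADM mid block instead of UNetMidBlock2D
    attention_type="adm",                # ADM attention in the Attn down/up blocks
    down_block_types=("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
    up_block_types=("AttnUpBlock2D", "AttnUpBlock2D", "AttnUpBlock2D", "UpBlock2D"),
    block_out_channels=(128, 256, 384, 512),
    attention_head_dim=64,
)
out = model(torch.randn(1, 3, 64, 64), timestep=torch.tensor([10])).sample
print(out.shape)  # torch.Size([1, 3, 64, 64])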