Skip to content

Commit 7518788

Browse files
committed
add openai/guideddiffusion support
1 parent 8581d9b commit 7518788

File tree

6 files changed

+177
-11
lines changed

6 files changed

+177
-11
lines changed
Lines changed: 102 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import argparse
2+
import os
3+
4+
import torch
5+
from convert_consistency_to_diffusers import con_pt_to_diffuser
6+
7+
from diffusers import (
8+
UNet2DModel,
9+
)
10+
11+
12+
SMALL_256_UNET_CONFIG = {
13+
"sample_size": 256,
14+
"in_channels": 3,
15+
"out_channels": 6,
16+
"layers_per_block": 1,
17+
"num_class_embeds": None,
18+
"block_out_channels": [128, 128, 128 * 2, 128 * 2, 128 * 4, 128 * 4],
19+
"attention_head_dim": 64,
20+
"down_block_types": [
21+
"ResnetDownsampleBlock2D",
22+
"ResnetDownsampleBlock2D",
23+
"ResnetDownsampleBlock2D",
24+
"ResnetDownsampleBlock2D",
25+
"AttnDownBlock2D",
26+
"ResnetDownsampleBlock2D",
27+
],
28+
"up_block_types": [
29+
"ResnetUpsampleBlock2D",
30+
"AttnUpBlock2D",
31+
"ResnetUpsampleBlock2D",
32+
"ResnetUpsampleBlock2D",
33+
"ResnetUpsampleBlock2D",
34+
"ResnetUpsampleBlock2D",
35+
],
36+
"resnet_time_scale_shift": "scale_shift",
37+
"upsample_type": "resnet",
38+
"downsample_type": "resnet",
39+
"norm_eps": 1e-06,
40+
"norm_num_groups": 32,
41+
}
42+
43+
44+
LARGE_256_UNET_CONFIG = {
45+
"sample_size": 256,
46+
"in_channels": 3,
47+
"out_channels": 6,
48+
"layers_per_block": 2,
49+
"num_class_embeds": None,
50+
"block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4],
51+
"attention_head_dim": 64,
52+
"down_block_types": [
53+
"ResnetDownsampleBlock2D",
54+
"ResnetDownsampleBlock2D",
55+
"ResnetDownsampleBlock2D",
56+
"AttnDownBlock2D",
57+
"AttnDownBlock2D",
58+
"AttnDownBlock2D",
59+
],
60+
"up_block_types": [
61+
"AttnUpBlock2D",
62+
"AttnUpBlock2D",
63+
"AttnUpBlock2D",
64+
"ResnetUpsampleBlock2D",
65+
"ResnetUpsampleBlock2D",
66+
"ResnetUpsampleBlock2D",
67+
],
68+
"resnet_time_scale_shift": "scale_shift",
69+
"upsample_type": "resnet",
70+
"downsample_type": "resnet",
71+
"norm_eps": 1e-06,
72+
"norm_num_groups": 32,
73+
}
74+
75+
76+
if __name__ == "__main__":
77+
parser = argparse.ArgumentParser()
78+
79+
parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.")
80+
parser.add_argument(
81+
"--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model."
82+
)
83+
84+
args = parser.parse_args()
85+
86+
ckpt_name = os.path.basename(args.unet_path)
87+
print(f"Checkpoint: {ckpt_name}")
88+
89+
# Get U-Net config
90+
if "ffhq" in ckpt_name:
91+
unet_config = SMALL_256_UNET_CONFIG
92+
else:
93+
unet_config = LARGE_256_UNET_CONFIG
94+
95+
unet_config["num_class_embeds"] = None
96+
97+
converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config)
98+
99+
image_unet = UNet2DModel(**unet_config)
100+
image_unet.load_state_dict(converted_unet_ckpt)
101+
102+
torch.save(converted_unet_ckpt, args.dump_path)

src/diffusers/models/attention_processor.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class Attention(nn.Module):
8484
processor (`AttnProcessor`, *optional*, defaults to `None`):
8585
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
8686
`AttnProcessor` otherwise.
87+
attention_legacy_order (`bool`, *optional*, defaults to `False`):
88+
if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
8789
"""
8890

8991
def __init__(
@@ -110,6 +112,7 @@ def __init__(
110112
_from_deprecated_attn_block: bool = False,
111113
processor: Optional["AttnProcessor"] = None,
112114
out_dim: int = None,
115+
attention_legacy_order: bool = False,
113116
):
114117
super().__init__()
115118
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
@@ -205,6 +208,7 @@ def __init__(
205208
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
206209
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
207210
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
211+
self.attention_legacy_order = attention_legacy_order
208212
if processor is None:
209213
processor = (
210214
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
@@ -1221,6 +1225,7 @@ def __call__(
12211225
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
12221226

12231227
args = () if USE_PEFT_BACKEND else (scale,)
1228+
12241229
query = attn.to_q(hidden_states, *args)
12251230

12261231
if encoder_hidden_states is None:
@@ -1234,11 +1239,16 @@ def __call__(
12341239
inner_dim = key.shape[-1]
12351240
head_dim = inner_dim // attn.heads
12361241

1237-
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1238-
1239-
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1240-
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1241-
1242+
if attn.attention_legacy_order:
1243+
qkv = torch.cat([query, key, value], dim=2).transpose(1, 2)
1244+
query, key, value = qkv.reshape(batch_size, attn.heads, head_dim * 3, -1).chunk(3, dim=2)
1245+
query = query.transpose(-1, -2)
1246+
key = key.transpose(-1, -2)
1247+
value = value.transpose(-1, -2)
1248+
else:
1249+
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1250+
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
1251+
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
12421252
# the output of sdp = (batch, num_heads, seq_len, head_dim)
12431253
# TODO: add support for attn.scale when we move to Torch 2.1
12441254
hidden_states = F.scaled_dot_product_attention(

src/diffusers/models/embeddings.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,38 @@ def forward(self, timesteps):
254254
return t_emb
255255

256256

257+
def timestep_embedding_adm(timesteps, dim, max_period=10000):
258+
"""
259+
ADM order embedding from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103
260+
"""
261+
half = dim // 2
262+
freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(
263+
device=timesteps.device
264+
)
265+
args = timesteps[:, None].float() * freqs[None]
266+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
267+
if dim % 2:
268+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
269+
return embedding
270+
271+
272+
class TimestepsADM(nn.Module):
273+
"""
274+
ADM order embedding from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103
275+
"""
276+
277+
def __init__(self, num_channels: int):
278+
super().__init__()
279+
self.num_channels = num_channels
280+
281+
def forward(self, timesteps):
282+
t_emb = timestep_embedding_adm(
283+
timesteps,
284+
self.num_channels,
285+
)
286+
return t_emb
287+
288+
257289
class GaussianFourierProjection(nn.Module):
258290
"""Gaussian Fourier embeddings for noise levels."""
259291

src/diffusers/models/unet_2d_blocks.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def get_down_block(
6565
upcast_attention: bool = False,
6666
resnet_time_scale_shift: str = "default",
6767
attention_type: str = "default",
68+
attention_legacy_order: bool = False,
6869
resnet_skip_time_act: bool = False,
6970
resnet_out_scale_factor: float = 1.0,
7071
cross_attention_norm: Optional[str] = None,
@@ -97,6 +98,7 @@ def get_down_block(
9798
upcast_attention=upcast_attention,
9899
resnet_time_scale_shift=resnet_time_scale_shift,
99100
attention_type=attention_type,
101+
attention_legacy_order=attention_legacy_order,
100102
resnet_skip_time_act=resnet_skip_time_act,
101103
resnet_out_scale_factor=resnet_out_scale_factor,
102104
cross_attention_norm=cross_attention_norm,
@@ -202,6 +204,7 @@ def get_up_block(
202204
upcast_attention: bool = False,
203205
resnet_time_scale_shift: str = "default",
204206
attention_type: str = "default",
207+
attention_legacy_order: bool = False,
205208
resnet_skip_time_act: bool = False,
206209
resnet_out_scale_factor: float = 1.0,
207210
cross_attention_norm: Optional[str] = None,
@@ -235,6 +238,7 @@ def get_up_block(
235238
upcast_attention=upcast_attention,
236239
resnet_time_scale_shift=resnet_time_scale_shift,
237240
attention_type=attention_type,
241+
attention_legacy_order=attention_legacy_order,
238242
resnet_skip_time_act=resnet_skip_time_act,
239243
resnet_out_scale_factor=resnet_out_scale_factor,
240244
cross_attention_norm=cross_attention_norm,

src/diffusers/models/unets/unet_2d.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ...configuration_utils import ConfigMixin, register_to_config
2121
from ...utils import BaseOutput
22-
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
22+
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps, TimestepsADM
2323
from ..modeling_utils import ModelMixin
2424
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block
2525

@@ -72,6 +72,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
7272
The upsample type for upsampling layers. Choose between "conv" and "resnet"
7373
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
7474
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
75+
attention_legacy_order (`bool`, *optional*, defaults to `False`):
76+
if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
7577
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
7678
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
7779
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
@@ -109,6 +111,7 @@ def __init__(
109111
upsample_type: str = "conv",
110112
dropout: float = 0.0,
111113
act_fn: str = "silu",
114+
attention_legacy_order: bool = False,
112115
attention_head_dim: Optional[int] = 8,
113116
norm_num_groups: int = 32,
114117
attn_norm_num_groups: Optional[int] = None,
@@ -148,7 +151,9 @@ def __init__(
148151
elif time_embedding_type == "learned":
149152
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
150153
timestep_input_dim = block_out_channels[0]
151-
154+
elif time_embedding_type == "adm":
155+
self.time_proj = TimestepsADM(block_out_channels[0])
156+
timestep_input_dim = block_out_channels[0]
152157
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
153158

154159
# class embedding
@@ -182,6 +187,7 @@ def __init__(
182187
resnet_eps=norm_eps,
183188
resnet_act_fn=act_fn,
184189
resnet_groups=norm_num_groups,
190+
attention_legacy_order=attention_legacy_order,
185191
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
186192
downsample_padding=downsample_padding,
187193
resnet_time_scale_shift=resnet_time_scale_shift,
@@ -203,6 +209,7 @@ def __init__(
203209
resnet_groups=norm_num_groups,
204210
attn_groups=attn_norm_num_groups,
205211
add_attention=add_attention,
212+
attention_legacy_order=attention_legacy_order,
206213
)
207214

208215
# up
@@ -226,6 +233,7 @@ def __init__(
226233
resnet_eps=norm_eps,
227234
resnet_act_fn=act_fn,
228235
resnet_groups=norm_num_groups,
236+
attention_legacy_order=attention_legacy_order,
229237
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
230238
resnet_time_scale_shift=resnet_time_scale_shift,
231239
upsample_type=upsample_type,

0 commit comments

Comments
 (0)