Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 102 additions & 0 deletions scripts/convert_adm_to_diffusers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import argparse
import os

import torch
from convert_consistency_to_diffusers import con_pt_to_diffuser

from diffusers import (
UNet2DModel,
)


SMALL_256_UNET_CONFIG = {
"sample_size": 256,
"in_channels": 3,
"out_channels": 6,
"layers_per_block": 1,
"num_class_embeds": None,
"block_out_channels": [128, 128, 128 * 2, 128 * 2, 128 * 4, 128 * 4],
"attention_head_dim": 64,
"down_block_types": [
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"ResnetDownsampleBlock2D",
"AttnDownBlock2D",
"ResnetDownsampleBlock2D",
],
"up_block_types": [
"ResnetUpsampleBlock2D",
"AttnUpBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
"ResnetUpsampleBlock2D",
],
"resnet_time_scale_shift": "scale_shift",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you show me the codepath that becomes effective when resnet_time_scale_shift == "scale_shift"?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

convert_adm_to_diffusers.py simply calls the con_pt_to_diffuser function defined in convert_consistency_to_diffusers.py.

First, resnet_time_scale_shift == "scale_shift" is not a new option introduced by this PR.

Setting resnet_time_scale_shift == "scale_shift" passes the argument through UNet2DModel. In the __init__ of UNet2DModel, it is passed into get_down_block, UNetMidBlock2D and get_up_block, and from there into each building block, such as ResnetDownsampleBlock2D, ResnetUpsampleBlock2D, AttnDownBlock2D and AttnUpBlock2D. The eventual effect of resnet_time_scale_shift == "scale_shift" is to set the ResnetBlock2D class's time_embedding_norm == "scale_shift". This option affects the shape of the resnet's time embedding https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py#L283, and the behaviour of the time embedding https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py#L378.

It has no special effect on UNetMidBlock2D, as I circumvent this problem in https://github.com/tongdaxu/diffusers/blob/main/src/diffusers/models/unets/unet_2d.py#L200.

Setting resnet_time_scale_shift == "scale_shift" is necessary in the model conversion script because the resnet's time embedding's input shape is doubled when resnet_time_scale_shift == "scale_shift".

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, thanks!

"upsample_type": "resnet",
"downsample_type": "resnet",
"norm_eps": 1e-06,
"norm_num_groups": 32,
}


# UNet2DModel keyword arguments for the large 256x256 checkpoint. The script's
# entry point selects this config for every checkpoint whose file name does
# not contain "ffhq".
LARGE_256_UNET_CONFIG = {
    "sample_size": 256,
    "in_channels": 3,
    "out_channels": 6,
    "layers_per_block": 2,
    "num_class_embeds": None,
    # Base width of 256 channels scaled by the per-level multipliers.
    "block_out_channels": [256 * multiplier for multiplier in (1, 1, 2, 2, 4, 4)],
    "attention_head_dim": 64,
    # Three plain resnet stages followed by three attention stages on the way
    # down; the up path lists the mirrored order.
    "down_block_types": ["ResnetDownsampleBlock2D"] * 3 + ["AttnDownBlock2D"] * 3,
    "up_block_types": ["AttnUpBlock2D"] * 3 + ["ResnetUpsampleBlock2D"] * 3,
    "resnet_time_scale_shift": "scale_shift",
    "upsample_type": "resnet",
    "downsample_type": "resnet",
    "norm_eps": 1e-06,
    "norm_num_groups": 32,
}


if __name__ == "__main__":
    # Command-line interface: both the input checkpoint and the output path
    # are mandatory.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--unet_path",
        default=None,
        type=str,
        required=True,
        help="Path to the unet.pt to convert.",
    )
    parser.add_argument(
        "--dump_path",
        default=None,
        type=str,
        required=True,
        help="Path to output the converted UNet model.",
    )
    args = parser.parse_args()

    ckpt_name = os.path.basename(args.unet_path)
    print(f"Checkpoint: {ckpt_name}")

    # The config is chosen purely from the checkpoint file name: names
    # containing "ffhq" use the small config, everything else the large one.
    unet_config = SMALL_256_UNET_CONFIG if "ffhq" in ckpt_name else LARGE_256_UNET_CONFIG
    # NOTE: this writes into the selected module-level config dict in place.
    unet_config["num_class_embeds"] = None

    state_dict = con_pt_to_diffuser(args.unet_path, unet_config)

    # Build the model and load the converted weights into it before anything
    # is written to disk.
    unet = UNet2DModel(**unet_config)
    unet.load_state_dict(state_dict)

    # Only the converted state dict is serialized, not the model object.
    torch.save(state_dict, args.dump_path)
20 changes: 15 additions & 5 deletions src/diffusers/models/attention_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ class Attention(nn.Module):
processor (`AttnProcessor`, *optional*, defaults to `None`):
The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and
`AttnProcessor` otherwise.
attention_legacy_order (`bool`, *optional*, defaults to `False`):
if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
"""

def __init__(
Expand All @@ -110,6 +112,7 @@ def __init__(
_from_deprecated_attn_block: bool = False,
processor: Optional["AttnProcessor"] = None,
out_dim: int = None,
attention_legacy_order: bool = False,
):
super().__init__()
self.inner_dim = out_dim if out_dim is not None else dim_head * heads
Expand Down Expand Up @@ -205,6 +208,7 @@ def __init__(
# We use the AttnProcessor2_0 by default when torch 2.x is used which uses
# torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
# but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1
self.attention_legacy_order = attention_legacy_order
if processor is None:
processor = (
AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
Expand Down Expand Up @@ -1221,6 +1225,7 @@ def __call__(
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

args = () if USE_PEFT_BACKEND else (scale,)

query = attn.to_q(hidden_states, *args)

if encoder_hidden_states is None:
Expand All @@ -1234,11 +1239,16 @@ def __call__(
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads

query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

if attn.attention_legacy_order:
qkv = torch.cat([query, key, value], dim=2).transpose(1, 2)
query, key, value = qkv.reshape(batch_size, attn.heads, head_dim * 3, -1).chunk(3, dim=2)
query = query.transpose(-1, -2)
key = key.transpose(-1, -2)
value = value.transpose(-1, -2)
else:
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
Expand Down
32 changes: 32 additions & 0 deletions src/diffusers/models/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,38 @@ def forward(self, timesteps):
return t_emb


def timestep_embedding_adm(timesteps, dim, max_period=10000):
    """
    Build sinusoidal timestep embeddings in ADM order (cosine half first, then sine half).

    Reference: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103

    Args:
        timesteps: tensor of timesteps; indexed as `timesteps[:, None]`, so presumably 1-D with one
            entry per batch element (TODO confirm against callers).
        dim: size of the last dimension of the returned embedding.
        max_period: base of the geometric frequency progression.

    Returns:
        A float tensor whose last dimension is `[cos | sin]`, with one trailing zero column appended
        when `dim` is odd.
    """
    half = dim // 2

    # Frequencies are computed in float32 and then moved to the timesteps' device.
    exponent = -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(exponent).to(device=timesteps.device)

    # Outer product: one row of phase angles per timestep.
    angles = timesteps.float().unsqueeze(1) * freqs.unsqueeze(0)
    embedding = torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)

    if dim % 2 != 0:
        # cos/sin only fill 2 * (dim // 2) columns; pad the remaining one with zeros.
        padding = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat((embedding, padding), dim=-1)
    return embedding


class TimestepsADM(nn.Module):
    """
    Module wrapper around `timestep_embedding_adm` that stores the embedding width.

    ADM order embedding from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103
    """

    def __init__(self, num_channels: int):
        super().__init__()
        # Width of the embedding produced by `forward`; the module holds no other state.
        self.num_channels = num_channels

    def forward(self, timesteps):
        # `max_period` is left at the default of `timestep_embedding_adm`.
        return timestep_embedding_adm(timesteps, self.num_channels)


class GaussianFourierProjection(nn.Module):
"""Gaussian Fourier embeddings for noise levels."""

Expand Down
4 changes: 4 additions & 0 deletions src/diffusers/models/unet_2d_blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ def get_down_block(
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
attention_legacy_order: bool = False,
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: float = 1.0,
cross_attention_norm: Optional[str] = None,
Expand Down Expand Up @@ -97,6 +98,7 @@ def get_down_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
attention_legacy_order=attention_legacy_order,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
Expand Down Expand Up @@ -202,6 +204,7 @@ def get_up_block(
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
attention_type: str = "default",
attention_legacy_order: bool = False,
resnet_skip_time_act: bool = False,
resnet_out_scale_factor: float = 1.0,
cross_attention_norm: Optional[str] = None,
Expand Down Expand Up @@ -235,6 +238,7 @@ def get_up_block(
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
attention_type=attention_type,
attention_legacy_order=attention_legacy_order,
resnet_skip_time_act=resnet_skip_time_act,
resnet_out_scale_factor=resnet_out_scale_factor,
cross_attention_norm=cross_attention_norm,
Expand Down
13 changes: 11 additions & 2 deletions src/diffusers/models/unets/unet_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from ...configuration_utils import ConfigMixin, register_to_config
from ...utils import BaseOutput
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps
from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps, TimestepsADM
from ..modeling_utils import ModelMixin
from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block

Expand Down Expand Up @@ -72,6 +72,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
The upsample type for upsampling layers. Choose between "conv" and "resnet"
dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
attention_legacy_order (`bool`, *optional*, defaults to `False`):
if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328
attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension.
norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization.
attn_norm_num_groups (`int`, *optional*, defaults to `None`):
Expand Down Expand Up @@ -109,6 +111,7 @@ def __init__(
upsample_type: str = "conv",
dropout: float = 0.0,
act_fn: str = "silu",
attention_legacy_order: bool = False,
attention_head_dim: Optional[int] = 8,
norm_num_groups: int = 32,
attn_norm_num_groups: Optional[int] = None,
Expand Down Expand Up @@ -148,7 +151,9 @@ def __init__(
elif time_embedding_type == "learned":
self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0])
timestep_input_dim = block_out_channels[0]

elif time_embedding_type == "adm":
self.time_proj = TimestepsADM(block_out_channels[0])
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

# class embedding
Expand Down Expand Up @@ -182,6 +187,7 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_legacy_order=attention_legacy_order,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
Expand All @@ -191,6 +197,7 @@ def __init__(
self.down_blocks.append(down_block)

# mid
attn_norm_num_groups = norm_num_groups if attention_legacy_order is True else attn_norm_num_groups
self.mid_block = UNetMidBlock2D(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
Expand All @@ -203,6 +210,7 @@ def __init__(
resnet_groups=norm_num_groups,
attn_groups=attn_norm_num_groups,
add_attention=add_attention,
attention_legacy_order=attention_legacy_order,
)

# up
Expand All @@ -226,6 +234,7 @@ def __init__(
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
attention_legacy_order=attention_legacy_order,
attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel,
resnet_time_scale_shift=resnet_time_scale_shift,
upsample_type=upsample_type,
Expand Down
Loading