From 111eac139f3dc6ff47c50810a35044cad9b323b1 Mon Sep 17 00:00:00 2001 From: tongda xu Date: Sat, 3 Feb 2024 19:37:42 +0800 Subject: [PATCH 1/2] code dev: support adm inference --- scripts/convert_adm_to_diffusers.py | 102 ++++++++++++++++++ src/diffusers/models/attention_processor.py | 20 +++- src/diffusers/models/embeddings.py | 32 ++++++ src/diffusers/models/unet_2d_blocks.py | 4 + src/diffusers/models/unets/unet_2d.py | 13 ++- src/diffusers/models/unets/unet_2d_blocks.py | 16 ++- .../versatile_diffusion/modeling_text_unet.py | 7 +- 7 files changed, 182 insertions(+), 12 deletions(-) create mode 100644 scripts/convert_adm_to_diffusers.py diff --git a/scripts/convert_adm_to_diffusers.py b/scripts/convert_adm_to_diffusers.py new file mode 100644 index 000000000000..97b60b96e101 --- /dev/null +++ b/scripts/convert_adm_to_diffusers.py @@ -0,0 +1,102 @@ +import argparse +import os + +import torch +from convert_consistency_to_diffusers import con_pt_to_diffuser + +from diffusers import ( + UNet2DModel, +) + + +SMALL_256_UNET_CONFIG = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 6, + "layers_per_block": 1, + "num_class_embeds": None, + "block_out_channels": [128, 128, 128 * 2, 128 * 2, 128 * 4, 128 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "ResnetDownsampleBlock2D", + ], + "up_block_types": [ + "ResnetUpsampleBlock2D", + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "upsample_type": "resnet", + "downsample_type": "resnet", + "norm_eps": 1e-06, + "norm_num_groups": 32, +} + + +LARGE_256_UNET_CONFIG = { + "sample_size": 256, + "in_channels": 3, + "out_channels": 6, + "layers_per_block": 2, + "num_class_embeds": None, + "block_out_channels": [256, 256, 256 * 2, 256 * 2, 256 * 4, 256 * 4], + "attention_head_dim": 64, + "down_block_types": [ + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "ResnetDownsampleBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + "AttnDownBlock2D", + ], + "up_block_types": [ + "AttnUpBlock2D", + "AttnUpBlock2D", + "AttnUpBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + "ResnetUpsampleBlock2D", + ], + "resnet_time_scale_shift": "scale_shift", + "upsample_type": "resnet", + "downsample_type": "resnet", + "norm_eps": 1e-06, + "norm_num_groups": 32, +} + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument("--unet_path", default=None, type=str, required=True, help="Path to the unet.pt to convert.") + parser.add_argument( + "--dump_path", default=None, type=str, required=True, help="Path to output the converted UNet model." + ) + + args = parser.parse_args() + + ckpt_name = os.path.basename(args.unet_path) + print(f"Checkpoint: {ckpt_name}") + + # Get U-Net config + if "ffhq" in ckpt_name: + unet_config = SMALL_256_UNET_CONFIG + else: + unet_config = LARGE_256_UNET_CONFIG + + unet_config["num_class_embeds"] = None + + converted_unet_ckpt = con_pt_to_diffuser(args.unet_path, unet_config) + + image_unet = UNet2DModel(**unet_config) + image_unet.load_state_dict(converted_unet_ckpt) + + torch.save(converted_unet_ckpt, args.dump_path) diff --git a/src/diffusers/models/attention_processor.py b/src/diffusers/models/attention_processor.py index 908946119dc2..3de6e5e4b0c3 100644 --- a/src/diffusers/models/attention_processor.py +++ b/src/diffusers/models/attention_processor.py @@ -84,6 +84,8 @@ class Attention(nn.Module): processor (`AttnProcessor`, *optional*, defaults to `None`): The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and `AttnProcessor` otherwise. + attention_legacy_order (`bool`, *optional*, defaults to `False`): + if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328 """ def __init__( @@ -110,6 +112,7 @@ def __init__( _from_deprecated_attn_block: bool = False, processor: Optional["AttnProcessor"] = None, out_dim: int = None, + attention_legacy_order: bool = False, ): super().__init__() self.inner_dim = out_dim if out_dim is not None else dim_head * heads @@ -205,6 +208,7 @@ def __init__( # We use the AttnProcessor2_0 by default when torch 2.x is used which uses # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + self.attention_legacy_order = attention_legacy_order if processor is None: processor = ( AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor() @@ -1221,6 +1225,7 @@ def __call__( hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) args = () if USE_PEFT_BACKEND else (scale,) + query = attn.to_q(hidden_states, *args) if encoder_hidden_states is None: @@ -1234,11 +1239,16 @@ def __call__( inner_dim = key.shape[-1] head_dim = inner_dim // attn.heads - query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - - key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) - + if attn.attention_legacy_order: + qkv = torch.cat([query, key, value], dim=2).transpose(1, 2) + query, key, value = qkv.reshape(batch_size, attn.heads, head_dim * 3, -1).chunk(3, dim=2) + query = query.transpose(-1, -2) + key = key.transpose(-1, -2) + value = value.transpose(-1, -2) + else: + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) # the output of sdp = (batch, num_heads, seq_len, head_dim) # TODO: add support for attn.scale when we move to Torch 2.1 hidden_states = F.scaled_dot_product_attention( diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index 1ef035af10c1..c9d6ad2e776e 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -254,6 +254,38 @@ def forward(self, timesteps): return t_emb +def timestep_embedding_adm(timesteps, dim, max_period=10000): + """ + ADM order embedding from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103 + """ + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to( + device=timesteps.device + ) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepsADM(nn.Module): + """ + ADM order embedding from https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/nn.py#L103 + """ + + def __init__(self, num_channels: int): + super().__init__() + self.num_channels = num_channels + + def forward(self, timesteps): + t_emb = timestep_embedding_adm( + timesteps, + self.num_channels, + ) + return t_emb + + class GaussianFourierProjection(nn.Module): """Gaussian Fourier embeddings for noise levels.""" diff --git a/src/diffusers/models/unet_2d_blocks.py b/src/diffusers/models/unet_2d_blocks.py index 497eabfc607b..f983072674ca 100644 --- a/src/diffusers/models/unet_2d_blocks.py +++ b/src/diffusers/models/unet_2d_blocks.py @@ -65,6 +65,7 @@ def get_down_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", + attention_legacy_order: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -97,6 +98,7 @@ def get_down_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + attention_legacy_order=attention_legacy_order, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, @@ -202,6 +204,7 @@ def get_up_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", + attention_legacy_order: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -235,6 +238,7 @@ def get_up_block( upcast_attention=upcast_attention, resnet_time_scale_shift=resnet_time_scale_shift, attention_type=attention_type, + attention_legacy_order=attention_legacy_order, resnet_skip_time_act=resnet_skip_time_act, resnet_out_scale_factor=resnet_out_scale_factor, cross_attention_norm=cross_attention_norm, diff --git a/src/diffusers/models/unets/unet_2d.py b/src/diffusers/models/unets/unet_2d.py index 0a4ede51a7fd..202f54ea266d 100644 --- a/src/diffusers/models/unets/unet_2d.py +++ b/src/diffusers/models/unets/unet_2d.py @@ -19,7 +19,7 @@ from ...configuration_utils import ConfigMixin, register_to_config from ...utils import BaseOutput -from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps +from ..embeddings import GaussianFourierProjection, TimestepEmbedding, Timesteps, TimestepsADM from ..modeling_utils import ModelMixin from .unet_2d_blocks import UNetMidBlock2D, get_down_block, get_up_block @@ -72,6 +72,8 @@ class UNet2DModel(ModelMixin, ConfigMixin): The upsample type for upsampling layers. Choose between "conv" and "resnet" dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + attention_legacy_order (`bool`, *optional*, defaults to `False`): + if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328 attention_head_dim (`int`, *optional*, defaults to `8`): The attention head dimension. norm_num_groups (`int`, *optional*, defaults to `32`): The number of groups for normalization. attn_norm_num_groups (`int`, *optional*, defaults to `None`): @@ -109,6 +111,7 @@ def __init__( upsample_type: str = "conv", dropout: float = 0.0, act_fn: str = "silu", + attention_legacy_order: bool = False, attention_head_dim: Optional[int] = 8, norm_num_groups: int = 32, attn_norm_num_groups: Optional[int] = None, @@ -148,7 +151,9 @@ def __init__( elif time_embedding_type == "learned": self.time_proj = nn.Embedding(num_train_timesteps, block_out_channels[0]) timestep_input_dim = block_out_channels[0] - + elif time_embedding_type == "adm": + self.time_proj = TimestepsADM(block_out_channels[0]) + timestep_input_dim = block_out_channels[0] self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) # class embedding @@ -182,6 +187,7 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, + attention_legacy_order=attention_legacy_order, attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, downsample_padding=downsample_padding, resnet_time_scale_shift=resnet_time_scale_shift, @@ -191,6 +197,7 @@ def __init__( self.down_blocks.append(down_block) # mid + attn_norm_num_groups = norm_num_groups if attention_legacy_order is True else attn_norm_num_groups self.mid_block = UNetMidBlock2D( in_channels=block_out_channels[-1], temb_channels=time_embed_dim, @@ -203,6 +210,7 @@ def __init__( resnet_groups=norm_num_groups, attn_groups=attn_norm_num_groups, add_attention=add_attention, + attention_legacy_order=attention_legacy_order, ) # up @@ -226,6 +234,7 @@ def __init__( resnet_eps=norm_eps, resnet_act_fn=act_fn, resnet_groups=norm_num_groups, + attention_legacy_order=attention_legacy_order, attention_head_dim=attention_head_dim if attention_head_dim is not None else output_channel, resnet_time_scale_shift=resnet_time_scale_shift, upsample_type=upsample_type, diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 3796896ef675..0528c9fe7307 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -60,6 +60,7 @@ def get_down_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", + attention_legacy_order: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -73,7 +74,6 @@ def get_down_block( f"It is recommended to provide `attention_head_dim` when calling `get_down_block`. Defaulting `attention_head_dim` to {num_attention_heads}." ) attention_head_dim = num_attention_heads - down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type if down_block_type == "DownBlock2D": return DownBlock2D( @@ -119,6 +119,7 @@ def get_down_block( resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, downsample_padding=downsample_padding, + attention_legacy_order=attention_legacy_order, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, downsample_type=downsample_type, @@ -270,6 +271,7 @@ def get_up_block( upcast_attention: bool = False, resnet_time_scale_shift: str = "default", attention_type: str = "default", + attention_legacy_order: bool = False, resnet_skip_time_act: bool = False, resnet_out_scale_factor: float = 1.0, cross_attention_norm: Optional[str] = None, @@ -382,6 +384,7 @@ def get_up_block( resnet_eps=resnet_eps, resnet_act_fn=resnet_act_fn, resnet_groups=resnet_groups, + attention_legacy_order=attention_legacy_order, attention_head_dim=attention_head_dim, resnet_time_scale_shift=resnet_time_scale_shift, upsample_type=upsample_type, @@ -534,6 +537,8 @@ class UNetMidBlock2D(nn.Module): attention_head_dim (`int`, *optional*, defaults to 1): Dimension of a single attention head. The number of attention heads is determined based on this value and the number of input channels. + attention_legacy_order (`bool`, *optional*, defaults to `False`): + if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328 output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. Returns: @@ -549,13 +554,14 @@ def __init__( dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial + resnet_time_scale_shift: str = "default", # default, spatial, scale_shift resnet_act_fn: str = "swish", resnet_groups: int = 32, attn_groups: Optional[int] = None, resnet_pre_norm: bool = True, add_attention: bool = True, attention_head_dim: int = 1, + attention_legacy_order: bool = False, output_scale_factor: float = 1.0, ): super().__init__() @@ -602,7 +608,6 @@ def __init__( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." ) attention_head_dim = in_channels - for _ in range(num_layers): if self.add_attention: attentions.append( @@ -618,6 +623,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + attention_legacy_order=attention_legacy_order, ) ) else: @@ -950,6 +956,7 @@ def __init__( resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim: int = 1, + attention_legacy_order: bool = False, output_scale_factor: float = 1.0, downsample_padding: int = 1, downsample_type: str = "conv", @@ -993,6 +1000,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + attention_legacy_order=attention_legacy_order, ) ) @@ -2159,6 +2167,7 @@ def __init__( resnet_groups: int = 32, resnet_pre_norm: bool = True, attention_head_dim: int = 1, + attention_legacy_order: bool = False, output_scale_factor: float = 1.0, upsample_type: str = "conv", ): @@ -2204,6 +2213,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + attention_legacy_order=attention_legacy_order, ) ) diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 60707cc1e2f7..9b432e5fbade 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -2095,6 +2095,8 @@ class UNetMidBlockFlat(nn.Module): attention_head_dim (`int`, *optional*, defaults to 1): Dimension of a single attention head. The number of attention heads is determined based on this value and the number of input channels. + attention_legacy_order (`bool`, *optional*, defaults to `False`): + if attention_legacy_order, split heads before split qkv, see https://github.com/openai/guided-diffusion/blob/main/guided_diffusion/unet.py#L328 output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor. Returns: @@ -2110,13 +2112,14 @@ def __init__( dropout: float = 0.0, num_layers: int = 1, resnet_eps: float = 1e-6, - resnet_time_scale_shift: str = "default", # default, spatial + resnet_time_scale_shift: str = "default", # default, spatial, scale_shift resnet_act_fn: str = "swish", resnet_groups: int = 32, attn_groups: Optional[int] = None, resnet_pre_norm: bool = True, add_attention: bool = True, attention_head_dim: int = 1, + attention_legacy_order: bool = False, output_scale_factor: float = 1.0, ): super().__init__() @@ -2163,7 +2166,6 @@ def __init__( f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}." ) attention_head_dim = in_channels - for _ in range(num_layers): if self.add_attention: attentions.append( @@ -2179,6 +2181,7 @@ def __init__( bias=True, upcast_softmax=True, _from_deprecated_attn_block=True, + attention_legacy_order=attention_legacy_order, ) ) else: From f817d9ad082e4f89eaf55cd3df5be89c0d57637e Mon Sep 17 00:00:00 2001 From: "T. Xu" Date: Mon, 5 Feb 2024 11:44:34 +0800 Subject: [PATCH 2/2] Update src/diffusers/models/embeddings.py Co-authored-by: Sayak Paul --- src/diffusers/models/embeddings.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/src/diffusers/models/embeddings.py b/src/diffusers/models/embeddings.py index c9d6ad2e776e..9671a7b8e887 100644 --- a/src/diffusers/models/embeddings.py +++ b/src/diffusers/models/embeddings.py @@ -279,10 +279,7 @@ def __init__(self, num_channels: int): self.num_channels = num_channels def forward(self, timesteps): - t_emb = timestep_embedding_adm( - timesteps, - self.num_channels, - ) + t_emb = timestep_embedding_adm(timesteps, self.num_channels) return t_emb