import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers.models.embeddings import get_timestep_embedding
from ...utils import replace_unet_conv_in, replace_attention_mask_method, add_aux_conv_in
from ...utils.replace import CustomUNet
import random
import os

# Resolve nesting differences in offline local checkpoint directories,
# e.g. when the config file lives at "subdir/subdir/config.json" instead of "subdir/config.json".
def _resolve_nested_dir(base_dir: str, subdir: str, config_filename: str) -> str:
    direct = os.path.join(base_dir, subdir)
    nested = os.path.join(base_dir, subdir, subdir)
    if os.path.exists(os.path.join(direct, config_filename)):
        return direct
    if os.path.exists(os.path.join(nested, config_filename)):
        return nested
    return direct

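# Maps each auxiliary input type to the key under which its coordinate tensor is stored in the data dict.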
AUX_INPUT_DIT = {
    "auto_mask": "auto_coords",
    "point_mask": "point_coords",
    "bbox_mask": "bbox_coords",
    "mask": "mask_coords",
    "trimap": "trimap_coords",
}

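# SDMatte couples Stable Diffusion components (CLIP text encoder, VAE, a customized UNet)
# to predict an alpha matte from an RGB image plus an auxiliary guidance input
# (point/bbox/mask/trimap) and its coordinates.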
class SDMatte(nn.Module):
    def __init__(
        self,
        pretrained_model_name_or_path,
        conv_scale=3,
        num_inference_steps=1,
        aux_input="bbox_mask",
        use_aux_input=False,
        use_coor_input=True,
        use_dis_loss=True,
        use_attention_mask=True,
        use_encoder_attention_mask=False,
        add_noise=False,
        attn_mask_aux_input=["point_mask", "bbox_mask", "mask"],
        aux_input_list=["point_mask", "bbox_mask", "mask"],
        use_encoder_hidden_states=True,
        residual_connection=False,
        use_attention_mask_list=[True, True, True],
        use_encoder_hidden_states_list=[True, True, True],
        load_weight=True,
    ):
        super().__init__()
        self.init_submodule(pretrained_model_name_or_path, load_weight)
        self.num_inference_steps = num_inference_steps
        self.aux_input = aux_input
        self.use_aux_input = use_aux_input
        self.use_coor_input = use_coor_input
        self.use_dis_loss = use_dis_loss
        self.use_attention_mask = use_attention_mask
        self.use_encoder_attention_mask = use_encoder_attention_mask
        self.add_noise = add_noise
        self.attn_mask_aux_input = attn_mask_aux_input
        self.aux_input_list = aux_input_list
        self.use_encoder_hidden_states = use_encoder_hidden_states
        if use_encoder_hidden_states:
            self.unet = add_aux_conv_in(self.unet)
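        # conv_scale appears to count the latent tensors concatenated along the channel
        # dimension before the UNet's first conv; drop one each for disabled noise and
        # disabled auxiliary input, and widen conv_in only when more than one remains.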
        if not add_noise:
            conv_scale -= 1
        if not use_aux_input:
            conv_scale -= 1
        if conv_scale > 1:
            self.unet = replace_unet_conv_in(self.unet, conv_scale)
        replace_attention_mask_method(self.unet, residual_connection)
        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.unet.train()
        self.unet.use_attention_mask_list = use_attention_mask_list
        self.unet.use_encoder_hidden_states_list = use_encoder_hidden_states_list

    def init_submodule(self, pretrained_model_name_or_path, load_weight):
        if load_weight:
            text_dir = _resolve_nested_dir(pretrained_model_name_or_path, "text_encoder", "config.json")
            vae_dir = _resolve_nested_dir(pretrained_model_name_or_path, "vae", "config.json")
            unet_dir = _resolve_nested_dir(pretrained_model_name_or_path, "unet", "config.json")
            sched_dir = _resolve_nested_dir(pretrained_model_name_or_path, "scheduler", "scheduler_config.json")
            tok_dir = _resolve_nested_dir(pretrained_model_name_or_path, "tokenizer", "tokenizer_config.json")

            self.text_encoder = CLIPTextModel.from_pretrained(text_dir)
            self.vae = AutoencoderKL.from_pretrained(vae_dir)
            self.unet = CustomUNet.from_pretrained(
                unet_dir, low_cpu_mem_usage=True, ignore_mismatched_sizes=False
            )
            self.noise_scheduler = DDIMScheduler.from_pretrained(sched_dir)
            self.tokenizer = CLIPTokenizer.from_pretrained(tok_dir)
        else:
            # Build the modules from configs only (randomly initialized weights);
            # the tokenizer still loads its vocabulary files from disk.
            text_dir = _resolve_nested_dir(pretrained_model_name_or_path, "text_encoder", "config.json")
            text_config = CLIPTextConfig.from_pretrained(text_dir)
            self.text_encoder = CLIPTextModel(text_config)

            vae_path = _resolve_nested_dir(pretrained_model_name_or_path, "vae", "config.json")
            self.vae = AutoencoderKL.from_config(AutoencoderKL.load_config(vae_path))

            unet_path = _resolve_nested_dir(pretrained_model_name_or_path, "unet", "config.json")
            self.unet = CustomUNet.from_config(
                CustomUNet.load_config(unet_path),
                low_cpu_mem_usage=True,
                ignore_mismatched_sizes=False,
            )

            sched_dir = _resolve_nested_dir(pretrained_model_name_or_path, "scheduler", "scheduler_config.json")
            scheduler_path = os.path.join(sched_dir, "scheduler_config.json")
            self.noise_scheduler = DDIMScheduler.from_config(DDIMScheduler.load_config(scheduler_path))

            tok_dir = _resolve_nested_dir(pretrained_model_name_or_path, "tokenizer", "tokenizer_config.json")
            self.tokenizer = CLIPTokenizer.from_pretrained(tok_dir)

    def forward(self, data):
        rgb = data["image"].cuda()
        B = rgb.shape[0]

        # Choose the auxiliary input type: random during training when none is fixed,
        # otherwise the configured type ("point_mask" at eval time when unset).
        if self.aux_input is None and self.training:
            aux_input_type = random.choice(self.aux_input_list)
        elif self.aux_input is None:
            aux_input_type = "point_mask"
        else:
            aux_input_type = self.aux_input

        # get aux input latent
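        # Encode the auxiliary map with the VAE encoder and keep only the posterior mean
        # (no sampling), so the auxiliary latent is deterministic.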
        if self.use_aux_input:
            aux_input = data[aux_input_type].cuda()
            aux_input = aux_input.repeat(1, 3, 1, 1)
            aux_input_h = self.vae.encoder(aux_input.to(rgb.dtype))
            aux_input_moments = self.vae.quant_conv(aux_input_h)
            aux_input_mean, _ = torch.chunk(aux_input_moments, 2, dim=1)
            aux_input_latent = aux_input_mean * self.vae.config.scaling_factor
        else:
            aux_input_latent = None

        # get aux coordinate
        coor_name = AUX_INPUT_DIT[aux_input_type]
        coor = data[coor_name].cuda()
        if coor_name == "point_coords":
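            # Pad the N point coordinates up to the smallest length i >= N that divides 1680,
            # so that embedding each value into 1680 // i channels produces a fixed total of
            # 1680 embedding values per sample regardless of how many points were given.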
            N = coor.shape[1]
            for i in range(N, 1680):
                if 1680 % i == 0:
                    num_channels = 1680 // i
                    pad_size = i - N
                    padding = torch.zeros((B, pad_size), dtype=coor.dtype, device=coor.device)
                    coor = torch.cat([coor, padding], dim=1)
                    zero_coor = torch.zeros((B, pad_size + N), dtype=coor.dtype, device=coor.device)
                    break
            if self.use_coor_input:
                coor = get_timestep_embedding(
                    coor.flatten(),
                    num_channels,
                    flip_sin_to_cos=True,
                    downscale_freq_shift=0,
                )
            else:
                # Embed all-zero coordinates when coordinate input is disabled.
                coor = get_timestep_embedding(
                    zero_coor.flatten(),
                    num_channels,
                    flip_sin_to_cos=True,
                    downscale_freq_shift=0,
                )
            added_cond_kwargs = {"point_coords": coor}
        else:
            if self.use_coor_input:
                added_cond_kwargs = {"bbox_mask_coords": coor}
            else:
                # Fall back to a full-image bbox in normalized coordinates.
                coor = torch.tensor([[0, 0, 1, 1]] * B).cuda()
                added_cond_kwargs = {"bbox_mask_coords": coor}

        # get attention mask
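        # Rescale the guidance map from [-1, 1] to [0, 1] and downsample it 8x to match the
        # spatial resolution of the VAE latents, then flatten it per sample.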
        if self.use_attention_mask and aux_input_type in self.attn_mask_aux_input:
            attention_mask = data[aux_input_type].cuda()
            attention_mask = (attention_mask + 1) / 2
            attention_mask = F.interpolate(attention_mask, scale_factor=1 / 8, mode="nearest")
            attention_mask = attention_mask.flatten(start_dim=1)
        else:
            attention_mask = None

        # encode rgb to latents
        rgb_h = self.vae.encoder(rgb)
        rgb_moments = self.vae.quant_conv(rgb_h)
        rgb_mean, _ = torch.chunk(rgb_moments, 2, dim=1)
        rgb_latent = rgb_mean * self.vae.config.scaling_factor

        # get encoder_hidden_states
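        # Project the auxiliary latent through aux_conv_in and reshape it to
        # (B, sequence_length, 1024) so it can serve as cross-attention context.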
        if self.use_encoder_hidden_states and aux_input_latent is not None:
            encoder_hidden_states = self.unet.aux_conv_in(aux_input_latent)
            encoder_hidden_states = encoder_hidden_states.view(B, 1024, -1)
            encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
        else:
            # Without this branch, encoder_hidden_states would be undefined below;
            # passing None assumes CustomUNet tolerates a missing auxiliary context.
            encoder_hidden_states = None

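        # Encode the caption (or an empty prompt) with the CLIP tokenizer and text encoder
        # as a second conditioning stream.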
        if "caption" in data:
            prompt = data["caption"]
        else:
            prompt = [""] * B
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to("cuda")
        text_embed = self.text_encoder(text_input_ids)[0]
        encoder_hidden_states_2 = text_embed

        # get class_label
        is_trans = data["is_trans"].cuda()
        trans = 1 - is_trans

        # get timesteps
        timestep = torch.tensor([1], device="cuda").long()
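        # Note: `timestep` is built here but the UNet call below passes timestep=None,
        # so this tensor is currently unused.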

        # unet
        # Concatenating assumes use_aux_input=True; fall back to the RGB latent alone when
        # no auxiliary latent is available (matching the conv_scale adjustment in __init__).
        if aux_input_latent is not None:
            unet_input = torch.cat([rgb_latent, aux_input_latent], dim=1)
        else:
            unet_input = rgb_latent
        label_latent = self.unet(
            sample=unet_input,
            trans=trans,
            timestep=None,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_states_2=encoder_hidden_states_2,
            added_cond_kwargs=added_cond_kwargs,
            attention_mask=attention_mask,
        ).sample
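        # Decode the predicted latent with the VAE: undo the scaling factor, pass through
        # post_quant_conv and the decoder, then average the output channels into one map.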
        label_latent = label_latent / self.vae.config.scaling_factor
        z = self.vae.post_quant_conv(label_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        label_mean = stacked.mean(dim=1, keepdim=True)
        output = torch.clip(label_mean, -1.0, 1.0)
        output = (output + 1.0) / 2.0
        return output
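
# Expected structure of `data` (shapes and value ranges are inferred from this file,
# not from an external spec):
#   "image":                  (B, 3, H, W) RGB tensor, roughly in [-1, 1]
#   "<aux_input_type>":       (B, 1, H, W) guidance map in [-1, 1], e.g. "bbox_mask"
#   AUX_INPUT_DIT[<type>]:    coordinate tensor for that guidance map
#   "is_trans":               per-sample transparency flag used as a class label
#   "caption":                optional text prompt(s); an empty prompt is used otherwise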