Commit 1e0bfa0

Add files via upload
1 parent 25c07e9 commit 1e0bfa0

File tree

7 files changed: +875 -0 lines changed

SDMatte/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Make the SDMatte directory an importable package

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
from .meta_arch import *

Lines changed: 235 additions & 0 deletions
@@ -0,0 +1,235 @@
import torch
from transformers import CLIPTextModel, CLIPTokenizer, CLIPTextConfig
import torch.nn as nn
import torch.nn.functional as F
from diffusers import DDIMScheduler, AutoencoderKL, UNet2DConditionModel
from diffusers.models.embeddings import get_timestep_embedding
from ...utils import replace_unet_conv_in, replace_attention_mask_method, add_aux_conv_in
from ...utils.replace import CustomUNet
import random
import os

# Resolve offline local-directory layout differences, e.g. when the files sit at "subdir/subdir/config.json"
def _resolve_nested_dir(base_dir: str, subdir: str, config_filename: str) -> str:
    direct = os.path.join(base_dir, subdir)
    nested = os.path.join(base_dir, subdir, subdir)
    if os.path.exists(os.path.join(direct, config_filename)):
        return direct
    if os.path.exists(os.path.join(nested, config_filename)):
        return nested
    return direct
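
# Usage sketch (the checkpoint path is an assumption, not part of this commit):
#   _resolve_nested_dir("./checkpoints/SDMatte", "vae", "config.json")
# returns "./checkpoints/SDMatte/vae" when config.json sits directly there, and the
# doubled "./checkpoints/SDMatte/vae/vae" when only the nested copy exists.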

# Map each aux-input type to the data-dict key that stores its coordinates.
AUX_INPUT_DIT = {
    "auto_mask": "auto_coords",
    "point_mask": "point_coords",
    "bbox_mask": "bbox_coords",
    "mask": "mask_coords",
    "trimap": "trimap_coords",
}

class SDMatte(nn.Module):
    def __init__(
        self,
        pretrained_model_name_or_path,
        conv_scale=3,
        num_inference_steps=1,
        aux_input="bbox_mask",
        use_aux_input=False,
        use_coor_input=True,
        use_dis_loss=True,
        use_attention_mask=True,
        use_encoder_attention_mask=False,
        add_noise=False,
        attn_mask_aux_input=["point_mask", "bbox_mask", "mask"],
        aux_input_list=["point_mask", "bbox_mask", "mask"],
        use_encoder_hidden_states=True,
        residual_connection=False,
        use_attention_mask_list=[True, True, True],
        use_encoder_hidden_states_list=[True, True, True],
        load_weight=True,
    ):
        super().__init__()
        self.init_submodule(pretrained_model_name_or_path, load_weight)
        self.num_inference_steps = num_inference_steps
        self.aux_input = aux_input
        self.use_aux_input = use_aux_input
        self.use_coor_input = use_coor_input
        self.use_dis_loss = use_dis_loss
        self.use_attention_mask = use_attention_mask
        self.use_encoder_attention_mask = use_encoder_attention_mask
        self.add_noise = add_noise
        self.attn_mask_aux_input = attn_mask_aux_input
        self.aux_input_list = aux_input_list
        self.use_encoder_hidden_states = use_encoder_hidden_states
        if use_encoder_hidden_states:
            self.unet = add_aux_conv_in(self.unet)
        # Shrink the input-conv multiplier for each disabled input stream.
        if not add_noise:
            conv_scale -= 1
        if not use_aux_input:
            conv_scale -= 1
        if conv_scale > 1:
            self.unet = replace_unet_conv_in(self.unet, conv_scale)
        replace_attention_mask_method(self.unet, residual_connection)
        self.text_encoder.requires_grad_(False)
        self.vae.requires_grad_(False)
        self.unet.train()
        self.unet.use_attention_mask_list = use_attention_mask_list
        self.unet.use_encoder_hidden_states_list = use_encoder_hidden_states_list
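
    # Construction sketch (the checkpoint directory is an assumption, not part of this commit):
    #   model = SDMatte("./checkpoints/SDMatte", use_aux_input=True, aux_input="mask")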

    def init_submodule(self, pretrained_model_name_or_path, load_weight):
        if load_weight:
            text_dir = _resolve_nested_dir(pretrained_model_name_or_path, "text_encoder", "config.json")
            vae_dir = _resolve_nested_dir(pretrained_model_name_or_path, "vae", "config.json")
            unet_dir = _resolve_nested_dir(pretrained_model_name_or_path, "unet", "config.json")
            sched_dir = _resolve_nested_dir(pretrained_model_name_or_path, "scheduler", "scheduler_config.json")
            tok_dir = _resolve_nested_dir(pretrained_model_name_or_path, "tokenizer", "tokenizer_config.json")

            self.text_encoder = CLIPTextModel.from_pretrained(text_dir)
            self.vae = AutoencoderKL.from_pretrained(vae_dir)
            self.unet = CustomUNet.from_pretrained(
                unet_dir, low_cpu_mem_usage=True, ignore_mismatched_sizes=False
            )
            self.noise_scheduler = DDIMScheduler.from_pretrained(sched_dir)
            self.tokenizer = CLIPTokenizer.from_pretrained(tok_dir)
        else:
            text_dir = _resolve_nested_dir(pretrained_model_name_or_path, "text_encoder", "config.json")
            text_config = CLIPTextConfig.from_pretrained(text_dir)
            self.text_encoder = CLIPTextModel(text_config)

            vae_path = _resolve_nested_dir(pretrained_model_name_or_path, "vae", "config.json")
            self.vae = AutoencoderKL.from_config(AutoencoderKL.load_config(vae_path))

            unet_path = _resolve_nested_dir(pretrained_model_name_or_path, "unet", "config.json")
            self.unet = CustomUNet.from_config(
                CustomUNet.load_config(unet_path),
                low_cpu_mem_usage=True,
                ignore_mismatched_sizes=False,
            )

            scheduler_path = os.path.join(
                _resolve_nested_dir(pretrained_model_name_or_path, "scheduler", "scheduler_config.json"),
                "scheduler_config.json",
            )
            self.noise_scheduler = DDIMScheduler.from_config(DDIMScheduler.load_config(scheduler_path))

            tok_dir = _resolve_nested_dir(pretrained_model_name_or_path, "tokenizer", "tokenizer_config.json")
            self.tokenizer = CLIPTokenizer.from_pretrained(tok_dir)
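
    # Note: with load_weight=False only the configs are read, so the submodules start
    # randomly initialized and a full state_dict is expected to be loaded afterwards.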

    def forward(self, data):
        rgb = data["image"].cuda()
        B = rgb.shape[0]

        if self.aux_input is None and self.training:
            aux_input_type = random.choice(self.aux_input_list)
        elif self.aux_input is None:
            aux_input_type = "point_mask"
        else:
            aux_input_type = self.aux_input

        # get aux input latent
        if self.use_aux_input:
            aux_input = data[aux_input_type].cuda()
            aux_input = aux_input.repeat(1, 3, 1, 1)
            aux_input_h = self.vae.encoder(aux_input.to(rgb.dtype))
            aux_input_moments = self.vae.quant_conv(aux_input_h)
            aux_input_mean, _ = torch.chunk(aux_input_moments, 2, dim=1)
            aux_input_latent = aux_input_mean * self.vae.config.scaling_factor
        else:
            aux_input_latent = None
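
        # The branch above encodes the aux input as the mean of the VAE posterior,
        # a deterministic equivalent of self.vae.encode(x).latent_dist.mean scaled
        # by the VAE's scaling_factor.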

        # get aux coordinate
        coor_name = AUX_INPUT_DIT[aux_input_type]
        coor = data[coor_name].cuda()
        if coor_name == "point_coords":
            N = coor.shape[1]
            # Pad the point list to the smallest length i >= N that divides 1680,
            # so that i * num_channels == 1680 yields a fixed-size embedding
            # (assumes N <= 840, so such a divisor exists below 1680).
            for i in range(N, 1680):
                if 1680 % i == 0:
                    num_channels = 1680 // i
                    pad_size = i - N
                    padding = torch.zeros((B, pad_size), dtype=coor.dtype, device=coor.device)
                    coor = torch.cat([coor, padding], dim=1)
                    zero_coor = torch.zeros((B, pad_size + N), dtype=coor.dtype, device=coor.device)
                    break
            if self.use_coor_input:
                # Embed each padded coordinate with the sinusoidal timestep embedding.
                coor = get_timestep_embedding(
                    coor.flatten(),
                    num_channels,
                    flip_sin_to_cos=True,
                    downscale_freq_shift=0,
                )
            else:
                coor = get_timestep_embedding(
                    zero_coor.flatten(),
                    num_channels,
                    flip_sin_to_cos=True,
                    downscale_freq_shift=0,
                )
            added_cond_kwargs = {"point_coords": coor}
        else:
            if self.use_coor_input:
                added_cond_kwargs = {"bbox_mask_coords": coor}
            else:
                coor = torch.tensor([[0, 0, 1, 1]] * B).cuda()
                added_cond_kwargs = {"bbox_mask_coords": coor}
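
        # For box/mask prompts the normalized coordinates are passed through raw;
        # when coordinate input is disabled, [0, 0, 1, 1] stands in as a whole-image box.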

        # get attention mask
        if self.use_attention_mask and aux_input_type in self.attn_mask_aux_input:
            attention_mask = data[aux_input_type].cuda()
            attention_mask = (attention_mask + 1) / 2  # map from [-1, 1] to [0, 1]
            attention_mask = F.interpolate(attention_mask, scale_factor=1 / 8, mode="nearest")
            attention_mask = attention_mask.flatten(start_dim=1)
        else:
            attention_mask = None

        # encode rgb to latents
        rgb_h = self.vae.encoder(rgb)
        rgb_moments = self.vae.quant_conv(rgb_h)
        rgb_mean, _ = torch.chunk(rgb_moments, 2, dim=1)
        rgb_latent = rgb_mean * self.vae.config.scaling_factor

        # get encoder_hidden_states
        if self.use_encoder_hidden_states and aux_input_latent is not None:
            encoder_hidden_states = self.unet.aux_conv_in(aux_input_latent)
            encoder_hidden_states = encoder_hidden_states.view(B, 1024, -1)
            encoder_hidden_states = encoder_hidden_states.permute(0, 2, 1)
        else:
            # avoid an undefined name when aux conditioning is off
            encoder_hidden_states = None

        if "caption" in data:
            prompt = data["caption"]
        else:
            prompt = [""] * B
        prompt = [prompt] if isinstance(prompt, str) else prompt
        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to("cuda")
        text_embed = self.text_encoder(text_input_ids)[0]
        encoder_hidden_states_2 = text_embed
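
        # Two conditioning streams feed the UNet: encoder_hidden_states carries the
        # projected aux-input latent, encoder_hidden_states_2 the CLIP text embedding.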

        # get class_label
        is_trans = data["is_trans"].cuda()
        trans = 1 - is_trans

        # get timesteps (defined but currently unused; the UNet is called with timestep=None)
        timestep = torch.tensor([1], device="cuda").long()

        # unet
        if aux_input_latent is not None:
            unet_input = torch.cat([rgb_latent, aux_input_latent], dim=1)
        else:
            # guard: without aux input the UNet input conv expects only the RGB latent
            unet_input = rgb_latent
        label_latent = self.unet(
            sample=unet_input,
            trans=trans,
            timestep=None,
            encoder_hidden_states=encoder_hidden_states,
            encoder_hidden_states_2=encoder_hidden_states_2,
            added_cond_kwargs=added_cond_kwargs,
            attention_mask=attention_mask,
        ).sample
        label_latent = label_latent / self.vae.config.scaling_factor
        z = self.vae.post_quant_conv(label_latent)
        stacked = self.vae.decoder(z)
        # mean of output channels
        label_mean = stacked.mean(dim=1, keepdim=True)
        output = torch.clip(label_mean, -1.0, 1.0)
        output = (output + 1.0) / 2.0
        return output
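
# Inference sketch (data-dict keys follow forward() above; the checkpoint path and
# tensor shapes are assumptions):
#   model = SDMatte("./checkpoints/SDMatte", use_aux_input=True, aux_input="mask").cuda().eval()
#   data = {"image": img, "mask": mask, "mask_coords": coords, "is_trans": flags}
#   with torch.no_grad():
#       alpha = model(data)  # (B, 1, H, W) matte with values in [0, 1]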

SDMatte/modeling/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .SDMatte import *
# from .LiteSDMatte import *

SDMatte/utils/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .utils import *
from .replace import *
