Commit 12855b2

update consisid
1 parent 61c85f7 commit 12855b2

17 files changed: +234 additions, -143 deletions

src/diffusers/models/transformers/consisid_transformer_3d.py

Lines changed: 7 additions & 7 deletions

@@ -75,9 +75,9 @@ def reshape_tensor(x, heads):
 class PerceiverAttention(nn.Module):
     """
     Implements the Perceiver attention mechanism with multi-head attention.
-
+
     This layer takes two inputs: 'x' (image features) and 'latents' (latent features),
-    applying multi-head attention to both and producing an output tensor with the same
+    applying multi-head attention to both and producing an output tensor with the same
     dimension as the input tensor 'x'.

     Args:
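
For orientation, the behaviour the docstring above describes (latent queries attending over image features via multi-head attention, with the output keeping the feature dimension of 'x') can be sketched in plain PyTorch. This toy is not the PerceiverAttention class from this file; its constructor arguments and its use of nn.MultiheadAttention are assumptions.

import torch
import torch.nn as nn

class TinyPerceiverAttention(nn.Module):
    # Toy perceiver-style attention: 'latents' query the concatenation of 'x' and 'latents'.
    def __init__(self, dim, heads=8):
        super().__init__()
        self.norm_x = nn.LayerNorm(dim)
        self.norm_latents = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, heads, batch_first=True)

    def forward(self, x, latents):
        x, latents = self.norm_x(x), self.norm_latents(latents)
        kv = torch.cat([x, latents], dim=1)   # keys/values: image features plus latents
        out, _ = self.attn(latents, kv, kv)   # queries: the latents
        return out                            # same feature dimension as 'x'

x = torch.randn(2, 256, 768)       # image features (B, N, D)
latents = torch.randn(2, 32, 768)  # latent queries (B, M, D)
print(TinyPerceiverAttention(768)(x, latents).shape)  # torch.Size([2, 32, 768])
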
@@ -522,19 +522,19 @@ class ConsisIDTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             which can help reduce computational overhead.
         LFE_num_tokens (`int`, defaults to `32`):
             The number of tokens to use in the Local Facial Extractor (LFE).
-            This module is responsible for capturing high frequency representations
+            This module is responsible for capturing high frequency representations
             of the face.
         LFE_output_dim (`int`, defaults to `768`):
             The output dimension of the Local Facial Extractor (LFE) module.
-            This dimension determines the size of the feature vectors produced
+            This dimension determines the size of the feature vectors produced
             by the LFE module.
         LFE_heads (`int`, defaults to `12`):
             The number of attention heads used in the Local Facial Extractor (LFE) module.
-            More heads may improve the ability to capture diverse features, but
+            More heads may improve the ability to capture diverse features, but
             can also increase computational complexity.
         local_face_scale (`float`, defaults to `1.0`):
-            A scaling factor used to adjust the importance of local facial features
-            in the model. This can influence how strongly the model focuses on
+            A scaling factor used to adjust the importance of local facial features
+            in the model. This can influence how strongly the model focuses on
             high frequency face-related content.
     """
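
A hedged configuration sketch for the parameters documented above. The import path follows this file's location; passing only these four keyword arguments assumes every other constructor argument keeps its default, which this diff does not show.

from diffusers.models.transformers.consisid_transformer_3d import ConsisIDTransformer3DModel

# The values below are the documented defaults; raising local_face_scale should weight
# high-frequency facial detail more heavily (an inference from the docstring, not tested here).
transformer = ConsisIDTransformer3DModel(
    LFE_num_tokens=32,     # tokens produced by the Local Facial Extractor
    LFE_output_dim=768,    # size of each LFE feature vector
    LFE_heads=12,          # attention heads inside the LFE
    local_face_scale=1.0,  # weight given to local facial features
)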

Lines changed: 33 additions & 8 deletions

@@ -1,11 +1,36 @@
 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
-from .factory import create_model, create_model_and_transforms, create_model_from_pretrained, get_tokenizer, create_transforms
-from .factory import list_models, add_model_config, get_model_config, load_checkpoint
+from .factory import (
+    add_model_config,
+    create_model,
+    create_model_and_transforms,
+    create_model_from_pretrained,
+    create_transforms,
+    get_model_config,
+    get_tokenizer,
+    list_models,
+    load_checkpoint,
+)
 from .loss import ClipLoss
-from .model import CLIP, CustomCLIP, CLIPTextCfg, CLIPVisionCfg,\
-    convert_weights_to_lp, convert_weights_to_fp16, trace_model, get_cast_dtype
-from .openai import load_openai_model, list_openai_models
-from .pretrained import list_pretrained, list_pretrained_models_by_tag, list_pretrained_tags_by_model,\
-    get_pretrained_url, download_pretrained_from_url, is_pretrained_cfg, get_pretrained_cfg, download_pretrained
+from .model import (
+    CLIP,
+    CLIPTextCfg,
+    CLIPVisionCfg,
+    CustomCLIP,
+    convert_weights_to_fp16,
+    convert_weights_to_lp,
+    get_cast_dtype,
+    trace_model,
+)
+from .openai import list_openai_models, load_openai_model
+from .pretrained import (
+    download_pretrained,
+    download_pretrained_from_url,
+    get_pretrained_cfg,
+    get_pretrained_url,
+    is_pretrained_cfg,
+    list_pretrained,
+    list_pretrained_models_by_tag,
+    list_pretrained_tags_by_model,
+)
 from .tokenizer import SimpleTokenizer, tokenize
-from .transform import image_transform
+from .transform import image_transform
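
Assuming this hunk is the util_clip package __init__.py (its relative imports match the sibling modules elsewhere in this commit), downstream code keeps importing the same public names; only the import layout changed. A minimal, hypothetical usage sketch:

# Package path inferred from the neighbouring files; not stated in this hunk.
from diffusers.pipelines.consisid.util_clip import create_model_and_transforms, list_models

print(list_models())  # model names discovered from the bundled model_configs/ directory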

src/diffusers/pipelines/consisid/util_clip/eva_vit_model.py

Lines changed: 29 additions & 26 deletions

@@ -4,16 +4,20 @@
 import math
 import os
 from functools import partial
+
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+
+
 try:
     from timm.models.layers import drop_path, to_2tuple, trunc_normal_
 except:
     from timm.layers import drop_path, to_2tuple, trunc_normal_
-
+
+from .rope import VisionRotaryEmbeddingFast
 from .transformer import PatchDropout
-from .rope import VisionRotaryEmbedding, VisionRotaryEmbeddingFast
+

 if os.getenv('ENV_TYPE') == 'deepspeed':
     try:

@@ -24,7 +28,6 @@
 from torch.utils.checkpoint import checkpoint

 try:
-    import xformers
     import xformers.ops as xops
     XFORMERS_IS_AVAILBLE = True
 except:

@@ -39,19 +42,19 @@ def __init__(self, drop_prob=None):

     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training)
-
+
     def extra_repr(self) -> str:
         return 'p={}'.format(self.drop_prob)


 class Mlp(nn.Module):
     def __init__(
-        self,
-        in_features,
-        hidden_features=None,
-        out_features=None,
-        act_layer=nn.GELU,
-        norm_layer=nn.LayerNorm,
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
         drop=0.,
         subln=False,


@@ -71,15 +74,15 @@ def forward(self, x):
         x = self.fc1(x)
         x = self.act(x)
         # x = self.drop(x)
-        # commit this for the orignal BERT implement
+        # commit this for the orignal BERT implement
         x = self.ffn_ln(x)

         x = self.fc2(x)
         x = self.drop(x)
         return x

 class SwiGLU(nn.Module):
-    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
+    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.SiLU, drop=0.,
                  norm_layer=nn.LayerNorm, subln=False):
         super().__init__()
         out_features = out_features or in_features

@@ -91,7 +94,7 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
         self.act = act_layer()
         self.ffn_ln = norm_layer(hidden_features) if subln else nn.Identity()
         self.w3 = nn.Linear(hidden_features, out_features)
-
+
         self.drop = nn.Dropout(drop)

     def forward(self, x):
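
The SwiGLU feed-forward touched above gates one projection with SiLU before the w3 output projection shown in the diff. A self-contained sketch of that pattern (the w1/w2 names are assumptions; only w3, the activation, and the optional ffn_ln appear in this hunk):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinySwiGLU(nn.Module):
    # Gated feed-forward: silu(w1(x)) * w2(x), projected back to the input width by w3.
    def __init__(self, dim, hidden):
        super().__init__()
        self.w1 = nn.Linear(dim, hidden)   # gate branch
        self.w2 = nn.Linear(dim, hidden)   # value branch
        self.w3 = nn.Linear(hidden, dim)   # output projection, as in the diff

    def forward(self, x):
        return self.w3(F.silu(self.w1(x)) * self.w2(x))

print(TinySwiGLU(64, 128)(torch.randn(2, 10, 64)).shape)  # torch.Size([2, 10, 64])
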
@@ -172,20 +175,20 @@ def __init__(

     def forward(self, x, rel_pos_bias=None, attn_mask=None):
         B, N, C = x.shape
-        if self.subln:
+        if self.subln:
             q = F.linear(input=x, weight=self.q_proj.weight, bias=self.q_bias)
             k = F.linear(input=x, weight=self.k_proj.weight, bias=None)
             v = F.linear(input=x, weight=self.v_proj.weight, bias=self.v_bias)

             q = q.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)  # B, num_heads, N, C
-            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
-            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
-        else:
+            k = k.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+            v = v.reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
+        else:

             qkv_bias = None
             if self.q_bias is not None:
                 qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
-
+
             qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
             qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)  # 3, B, num_heads, N, C
             q, k, v = qkv[0], qkv[1], qkv[2]
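
In the non-subln branch above, q, k and v come from a single fused linear call; because the key projection carries no bias, a zero block is spliced between the q and v biases before F.linear, and the result is reshaped per head. A standalone sketch with illustrative shapes:

import torch
import torch.nn.functional as F

B, N, C, num_heads = 2, 16, 64, 4
x = torch.randn(B, N, C)
qkv_weight = torch.randn(3 * C, C)
q_bias, v_bias = torch.zeros(C), torch.zeros(C)

# k gets no bias, so zeros sit between the q and v biases (mirroring the hunk above).
qkv_bias = torch.cat((q_bias, torch.zeros_like(v_bias), v_bias))
qkv = F.linear(x, qkv_weight, qkv_bias)                            # (B, N, 3*C)
qkv = qkv.reshape(B, N, 3, num_heads, -1).permute(2, 0, 3, 1, 4)   # (3, B, heads, N, C // heads)
q, k, v = qkv[0], qkv[1], qkv[2]
print(q.shape)  # torch.Size([2, 4, 16, 16])
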
@@ -232,7 +235,7 @@ def forward(self, x, rel_pos_bias=None, attn_mask=None):
             if attn_mask is not None:
                 attn_mask = attn_mask.bool()
                 attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
-
+
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)

@@ -262,15 +265,15 @@ def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None,

         if naiveswiglu:
             self.mlp = SwiGLU(
-                in_features=dim,
-                hidden_features=mlp_hidden_dim,
+                in_features=dim,
+                hidden_features=mlp_hidden_dim,
                 subln=subln,
                 norm_layer=norm_layer,
             )
         else:
             self.mlp = Mlp(
-                in_features=dim,
-                hidden_features=mlp_hidden_dim,
+                in_features=dim,
+                hidden_features=mlp_hidden_dim,
                 act_layer=act_layer,
                 subln=subln,
                 drop=drop

@@ -407,7 +410,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, em
                 ft_seq_len=hw_seq_len if intp_freq else None,
                 # patch_dropout=patch_dropout
             )
-        else:
+        else:
             self.rope = None

         self.naiveswiglu = naiveswiglu

@@ -469,7 +472,7 @@ def _init_weights(self, m):

     def get_num_layers(self):
         return len(self.blocks)
-
+
     def lock(self, unlocked_groups=0, freeze_bn_stats=False):
         assert unlocked_groups == 0, 'partial locking not currently supported for this model'
         for param in self.parameters():

@@ -491,7 +494,7 @@ def reset_classifier(self, num_classes, global_pool=''):
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

     def forward_features(self, x, return_all_features=False, return_hidden=False, shuffle=False):
-
+
         x = self.patch_embed(x)
         batch_size, seq_len, _ = x.size()
src/diffusers/pipelines/consisid/util_clip/factory.py

Lines changed: 15 additions & 16 deletions

@@ -1,24 +1,23 @@
 import json
 import logging
 import os
-import pathlib
 import re
 from copy import deepcopy
 from pathlib import Path
-from typing import Optional, Tuple, Union, Dict, Any
+from typing import Optional, Tuple, Union
+
 import torch

 from .constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
-from .model import CLIP, CustomCLIP, convert_weights_to_lp, convert_to_custom_text_state_dict,\
-    get_cast_dtype
+from .model import CLIP, CustomCLIP, convert_to_custom_text_state_dict, get_cast_dtype
 from .openai import load_openai_model
-from .pretrained import is_pretrained_cfg, get_pretrained_cfg, download_pretrained, list_pretrained_tags_by_model
-from .transform import image_transform
+from .pretrained import download_pretrained, get_pretrained_cfg, is_pretrained_cfg, list_pretrained_tags_by_model
 from .tokenizer import HFTokenizer, tokenize
-from .utils import resize_clip_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed, resize_eva_pos_embed
+from .transform import image_transform
+from .utils import resize_clip_pos_embed, resize_eva_pos_embed, resize_evaclip_pos_embed, resize_visual_pos_embed


-_MODEL_CONFIG_PATHS = [Path(__file__).parent / f"model_configs/"]
+_MODEL_CONFIG_PATHS = [Path(__file__).parent / "model_configs/"]
 _MODEL_CONFIGS = {}  # directory (model_name: config) of model architecture configs


@@ -93,7 +92,7 @@ def load_state_dict(checkpoint_path: str, map_location: str='cpu', model_key: st
         state_dict = checkpoint
     if next(iter(state_dict.items()))[0].startswith('module'):
         state_dict = {k[7:]: v for k, v in state_dict.items()}
-
+
     for k in skip_list:
         if k in list(state_dict.keys()):
             logging.info(f"Removing key {k} from pretrained checkpoint")
@@ -181,7 +180,7 @@ def load_pretrained_checkpoint(
         visual_state_dict = load_clip_visual_state_dict(visual_checkpoint_path, is_openai=True, skip_list=skip_list)
     else:
         visual_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)
-
+
     # resize_clip_pos_embed for CLIP and open CLIP
     if 'positional_embedding' in visual_state_dict:
         resize_visual_pos_embed(visual_state_dict, model)

@@ -202,7 +201,7 @@ def load_pretrained_checkpoint(
         text_state_dict = load_state_dict(visual_checkpoint_path, model_key=model_key, is_openai=False, skip_list=skip_list)

     text_incompatible_keys = model.text.load_state_dict(text_state_dict, strict=strict)
-
+
     logging.info(f"num of loaded text_state_dict keys: {len(text_state_dict.keys())}")
     logging.info(f"text_incompatible_keys.missing_keys: {text_incompatible_keys.missing_keys}")

@@ -255,7 +254,7 @@ def create_model(
     if force_quick_gelu:
         # override for use of QuickGELU on non-OpenAI transformer models
         model_cfg["quick_gelu"] = True
-
+
     if force_patch_dropout is not None:
         # override the default patch dropout value
         model_cfg['vision_cfg']["patch_dropout"] = force_patch_dropout

@@ -286,7 +285,7 @@ def create_model(
                 checkpoint_path,
                 model_key="model|module|state_dict",
                 strict=False
-            )
+            )
         else:
             error_str = (
                 f'Pretrained weights ({pretrained}) not found for model {model_name}.'

@@ -296,7 +295,7 @@ def create_model(
     else:
         visual_checkpoint_path = ''
         text_checkpoint_path = ''
-
+
         if pretrained_image:
             pretrained_visual_model = pretrained_visual_model.replace('/', '-')  # for callers using old naming with / in ViT names
             pretrained_image_cfg = get_pretrained_cfg(pretrained_visual_model, pretrained_image)

@@ -321,7 +320,7 @@ def create_model(
            else:
                logging.warning(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
                raise RuntimeError(f'Pretrained weights ({text_checkpoint_path}) not found for model {model_name}.text.')
-
+
        if visual_checkpoint_path:
            logging.info(f'Loading pretrained {model_name}.visual weights ({visual_checkpoint_path}).')
        if text_checkpoint_path:

@@ -338,7 +337,7 @@ def create_model(
                model_key="model|module|state_dict",
                skip_list=skip_list
            )
-
+
        if "fp16" in precision or "bf16" in precision:
            logging.info(f'convert precision to {precision}')
            model = model.to(torch.bfloat16) if 'bf16' in precision else model.to(torch.float16)
