Commit 73271ba

Update viclip_text.py
1 parent 59dc96a commit 73271ba

1 file changed: viclip_text.py (290 additions, 74 deletions)
@@ -1,81 +1,297 @@
+import os
+import logging
+from collections import OrderedDict
+from pkg_resources import packaging
 from .simple_tokenizer import SimpleTokenizer as _Tokenizer
-from .viclip import ViCLIP
-import torch
+
 import numpy as np
-import cv2
-import os
+import torch
+import torch.nn.functional as F
+from torch import nn
+import torch.utils.checkpoint as checkpoint
+import functools
+
+logger = logging.getLogger(__name__)
+
+
+# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
+MODEL_PATH = 'https://huggingface.co/laion'
+_MODELS = {
+    "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
+    "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
+}
+
+
+class LayerNorm(nn.LayerNorm):
+    """Subclass torch's LayerNorm to handle fp16."""
+
+    def forward(self, x: torch.Tensor):
+        orig_type = x.dtype
+        ret = super().forward(x.type(torch.float32))
+        return ret.type(orig_type)
+
 
-clip_candidates = {'viclip':None, 'clip':None}
-
-def get_clip(name='viclip',
-             size='l',
-             pretrain=os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth"),
-             force_reload=False):
-    global clip_candidates
-    m = clip_candidates[name]
-    if m is None or force_reload:
-        if name == 'viclip':
-            tokenizer = _Tokenizer()
-            vclip = ViCLIP(tokenizer=tokenizer, size=size, pretrain=pretrain)
-            # m = vclip
-            m = (vclip, tokenizer)
+class QuickGELU(nn.Module):
+    def forward(self, x: torch.Tensor):
+        return x * torch.sigmoid(1.702 * x)
+
+
+class ResidualAttentionBlock(nn.Module):
+    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
+        super().__init__()
+
+        self.attn = nn.MultiheadAttention(d_model, n_head)
+        self.ln_1 = LayerNorm(d_model)
+        self.mlp = nn.Sequential(OrderedDict([
+            ("c_fc", nn.Linear(d_model, d_model * 4)),
+            ("gelu", QuickGELU()),
+            ("c_proj", nn.Linear(d_model * 4, d_model))
+        ]))
+        self.ln_2 = LayerNorm(d_model)
+        self.attn_mask = attn_mask
+
+    def attention(self, x: torch.Tensor):
+        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
+        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]
+
+    def forward(self, x: torch.Tensor):
+        x = x + self.attention(self.ln_1(x))
+        x = x + self.mlp(self.ln_2(x))
+        return x
+
+
+class Transformer(nn.Module):
+    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
+                 checkpoint_num: int = 0):
+        super().__init__()
+        self.width = width
+        self.layers = layers
+        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])
+
+        self.checkpoint_num = checkpoint_num
+
+    def forward(self, x: torch.Tensor):
+        if self.checkpoint_num > 0:
+            segments = min(self.checkpoint_num, len(self.resblocks))
+            return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
         else:
-            raise Exception('the target clip model is not found.')
+            return self.resblocks(x)
+
+
+class CLIP_TEXT(nn.Module):
+    def __init__(
+        self,
+        embed_dim: int,
+        context_length: int,
+        vocab_size: int,
+        transformer_width: int,
+        transformer_heads: int,
+        transformer_layers: int,
+        checkpoint_num: int,
+    ):
+        super().__init__()
+
+        self.context_length = context_length
+        self._tokenizer = _Tokenizer()
+
+        self.transformer = Transformer(
+            width=transformer_width,
+            layers=transformer_layers,
+            heads=transformer_heads,
+            attn_mask=self.build_attention_mask(),
+            checkpoint_num=checkpoint_num,
+        )
+
+        self.vocab_size = vocab_size
+        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
+        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
+        self.ln_final = LayerNorm(transformer_width)
+
+        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))
 
-    return m
-
-def get_text_feat_dict(texts, clip, tokenizer, text_feat_d={}):
-    for t in texts:
-        feat = clip.get_text_features(t, tokenizer, text_feat_d)
-        text_feat_d[t] = feat
-    return text_feat_d
-
-def get_vid_feat(frames, clip):
-    return clip.get_vid_features(frames)
-
-def _frame_from_video(video):
-    while video.isOpened():
-        success, frame = video.read()
-        if success:
-            yield frame
+    def no_weight_decay(self):
+        return {'token_embedding', 'positional_embedding'}
+
+    @functools.lru_cache(maxsize=None)
+    def build_attention_mask(self):
+        # lazily create causal attention mask, with full attention between the vision tokens
+        # pytorch uses additive attention mask; fill with -inf
+        mask = torch.empty(self.context_length, self.context_length)
+        mask.fill_(float("-inf"))
+        mask.triu_(1)  # zero out the lower diagonal
+        return mask
+
+    def tokenize(self, texts, context_length=77, truncate=True):
+        """
+        Returns the tokenized representation of given input string(s)
+        Parameters
+        ----------
+        texts : Union[str, List[str]]
+            An input string or a list of input strings to tokenize
+        context_length : int
+            The context length to use; all CLIP models use 77 as the context length
+        truncate: bool
+            Whether to truncate the text in case its encoding is longer than the context length
+        Returns
+        -------
+        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
+        We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        sot_token = self._tokenizer.encoder["<|startoftext|>"]
+        eot_token = self._tokenizer.encoder["<|endoftext|>"]
+        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
+        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
+            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
         else:
-            break
-
-v_mean = np.array([0.485, 0.456, 0.406]).reshape(1,1,3)
-v_std = np.array([0.229, 0.224, 0.225]).reshape(1,1,3)
-def normalize(data):
-    return (data/255.0-v_mean)/v_std
-
-def frames2tensor(vid_list, fnum=8, target_size=(224, 224), device=torch.device('cuda')):
-    assert(len(vid_list) >= fnum)
-    step = len(vid_list) // fnum
-    vid_list = vid_list[::step][:fnum]
-    vid_list = [cv2.resize(x[:,:,::-1], target_size) for x in vid_list]
-    vid_tube = [np.expand_dims(normalize(x), axis=(0, 1)) for x in vid_list]
-    vid_tube = np.concatenate(vid_tube, axis=1)
-    vid_tube = np.transpose(vid_tube, (0, 1, 4, 2, 3))
-    vid_tube = torch.from_numpy(vid_tube).to(device, non_blocking=True).float()
-    return vid_tube
-
-def retrieve_text(frames,
-                  texts,
-                  name='viclip',
-                  model_cfg={'size':'l',
-                             'pretrained': 'os.path.join(os.path.dirname(os.path.abspath(__file__)), "ViClip-InternVid-10M-FLT.pth")',
-                             'reload':False},
-                  topk=5,
-                  device=torch.device('cuda')):
-    clip, tokenizer = get_clip(name, model_cfg['size'], model_cfg['pretrained'], model_cfg['reload'])
-    clip = clip.to(device)
-    frames_tensor = frames2tensor(frames, device=device)
-    vid_feat = get_vid_feat(frames_tensor, clip)
-
-    text_feat_d = {}
-    text_feat_d = get_text_feat_dict(texts, clip, tokenizer, text_feat_d)
-    text_feats = [text_feat_d[t] for t in texts]
-    text_feats_tensor = torch.cat(text_feats, 0)
-
-    probs, idxs = clip.get_predict_label(vid_feat, text_feats_tensor, top=topk)
+            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)
+
+        for i, tokens in enumerate(all_tokens):
+            if len(tokens) > context_length:
+                if truncate:
+                    tokens = tokens[:context_length]
+                    tokens[-1] = eot_token
+                else:
+                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
+            result[i, :len(tokens)] = torch.tensor(tokens)
+
+        return result
+
+    def forward(self, text):
+        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]
+
+        x = x + self.positional_embedding
+        x = x.permute(1, 0, 2)  # NLD -> LND
+        x = self.transformer(x)
+        x = x.permute(1, 0, 2)  # LND -> NLD
+        x = self.ln_final(x)
+
+        # x.shape = [batch_size, n_ctx, transformer.width]
+        # take features from the eot embedding (eot_token is the highest number in each sequence)
+        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection
+
+        return x
+
+
+def clip_text_b16(
+    embed_dim=512,
+    context_length=77,
+    vocab_size=49408,
+    transformer_width=512,
+    transformer_heads=8,
+    transformer_layers=12,
+    checkpoint_num=0,
+    pretrained=True,
+):
+    # raise NotImplementedError
+    model = CLIP_TEXT(
+        embed_dim,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers,
+        checkpoint_num,
+    )
+    # pretrained = _MODELS["ViT-B/16"]
+    # logger.info(f"Load pretrained weights from {pretrained}")
+    # state_dict = torch.load(pretrained, map_location='cpu')
+    # model.load_state_dict(state_dict, strict=False)
+    # return model.eval()
+    if pretrained:
+        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
+            pretrained = _MODELS[pretrained]
+        else:
+            pretrained = _MODELS["ViT-B/16"]
+        logger.info(f"Load pretrained weights from {pretrained}")
+        state_dict = torch.load(pretrained, map_location='cpu')
+        if context_length != state_dict["positional_embedding"].size(0):
+            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
+            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
+            if context_length < state_dict["positional_embedding"].size(0):
+                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
+            else:
+                state_dict["positional_embedding"] = F.pad(
+                    state_dict["positional_embedding"],
+                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
+                    value=0,
+                )
+
+        message = model.load_state_dict(state_dict, strict=False)
+        print(f"Load pretrained weights from {pretrained}: {message}")
+    return model.eval()
+
+
+def clip_text_l14(
+    embed_dim=768,
+    context_length=77,
+    vocab_size=49408,
+    transformer_width=768,
+    transformer_heads=12,
+    transformer_layers=12,
+    checkpoint_num=0,
+    pretrained=True,
+):
+    model = CLIP_TEXT(
+        embed_dim,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers,
+        checkpoint_num,
+    )
+    if pretrained:
+        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
+            pretrained = _MODELS[pretrained]
+        else:
+            pretrained = _MODELS["ViT-L/14"]
+        logger.info(f"Load pretrained weights from {pretrained}")
+        state_dict = torch.load(pretrained, map_location='cpu')
+        if context_length != state_dict["positional_embedding"].size(0):
+            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
+            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
+            if context_length < state_dict["positional_embedding"].size(0):
+                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
+            else:
+                state_dict["positional_embedding"] = F.pad(
+                    state_dict["positional_embedding"],
+                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
+                    value=0,
+                )
+
+        message = model.load_state_dict(state_dict, strict=False)
+        print(f"Load pretrained weights from {pretrained}: {message}")
+    return model.eval()
+
+
+def clip_text_l14_336(
+    embed_dim=768,
+    context_length=77,
+    vocab_size=49408,
+    transformer_width=768,
+    transformer_heads=12,
+    transformer_layers=12,
+):
+    raise NotImplementedError
+    model = CLIP_TEXT(
+        embed_dim,
+        context_length,
+        vocab_size,
+        transformer_width,
+        transformer_heads,
+        transformer_layers
+    )
+    pretrained = _MODELS["ViT-L/14_336"]
+    logger.info(f"Load pretrained weights from {pretrained}")
+    state_dict = torch.load(pretrained, map_location='cpu')
+    model.load_state_dict(state_dict, strict=False)
+    return model.eval()
+
 
-    ret_texts = [texts[i] for i in idxs.numpy()[0].tolist()]
-    return ret_texts, probs.numpy()[0]
+def build_clip(config):
+    model_cls = config.text_encoder.clip_teacher
+    model = eval(model_cls)()
+    return model
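
Net effect of the diff: viclip_text.py drops the ViCLIP/OpenCV retrieval helpers and becomes a self-contained CLIP text encoder (CLIP_TEXT) with factory functions clip_text_b16 / clip_text_l14 and a config-driven build_clip entry point. Below is a minimal usage sketch, not part of the commit; it assumes the module is imported from its package (the relative import of simple_tokenizer requires that), and the import path viclip_text is a placeholder to adjust to the repository's actual layout.

    import torch
    from viclip_text import clip_text_b16  # assumed import path; adjust to the package layout

    # Build the ViT-B/16 text tower without fetching pretrained weights.
    # Note: positional_embedding and text_projection are created with torch.empty,
    # so without a checkpoint the outputs are only useful for shape/plumbing checks.
    model = clip_text_b16(pretrained=False)

    # Tokenize two prompts into a [2, 77] tensor and encode them; features are
    # taken at the EOT position and projected to embed_dim=512.
    tokens = model.tokenize(["a dog playing in the snow",
                             "a person cooking pasta"])
    with torch.no_grad():
        text_feats = model(tokens)
    print(text_feats.shape)  # torch.Size([2, 512])

For the pretrained path, build_clip(config) reads config.text_encoder.clip_teacher (e.g. "clip_text_b16") and calls that factory with its defaults, i.e. pretrained=True. The _MODELS entries are URL-style strings, so in practice the corresponding vit_*_text.pth checkpoints would need to be downloaded and pointed at local files for torch.load to succeed.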
