import os
import logging
from collections import OrderedDict
from pkg_resources import packaging
from .simple_tokenizer import SimpleTokenizer as _Tokenizer

import numpy as np
import torch
import torch.nn.functional as F
from torch import nn
import torch.utils.checkpoint as checkpoint
import functools

logger = logging.getLogger(__name__)


# On P1, model extracted from https://huggingface.co/laion/CLIP-ViT-L-14-DataComp.XL-s13B-b90K
MODEL_PATH = 'https://huggingface.co/laion'
_MODELS = {
    "ViT-L/14": os.path.join(MODEL_PATH, "CLIP-ViT-L-14-DataComp.XL-s13B-b90K", "vit_l14_text.pth"),
    "ViT-B/16": os.path.join(MODEL_PATH, "CLIP-ViT-B-16-DataComp.XL-s13B-b90K", "vit_b16_text.pth"),
}
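# NOTE: os.path.join() on the URL above only builds placeholder strings; torch.load()
# in the factory functions below expects local files. To actually load these weights,
# point MODEL_PATH at a local directory that mirrors this layout (the *_text.pth files
# are text towers extracted from the LAION checkpoints referenced above), or edit
# _MODELS to map each key directly to a local checkpoint path.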


class LayerNorm(nn.LayerNorm):
    """Subclass torch's LayerNorm to handle fp16."""

    def forward(self, x: torch.Tensor):
        orig_type = x.dtype
        ret = super().forward(x.type(torch.float32))
        return ret.type(orig_type)


class QuickGELU(nn.Module):
    def forward(self, x: torch.Tensor):
        return x * torch.sigmoid(1.702 * x)


class ResidualAttentionBlock(nn.Module):
    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):
        super().__init__()

        self.attn = nn.MultiheadAttention(d_model, n_head)
        self.ln_1 = LayerNorm(d_model)
        self.mlp = nn.Sequential(OrderedDict([
            ("c_fc", nn.Linear(d_model, d_model * 4)),
            ("gelu", QuickGELU()),
            ("c_proj", nn.Linear(d_model * 4, d_model))
        ]))
        self.ln_2 = LayerNorm(d_model)
        self.attn_mask = attn_mask

    def attention(self, x: torch.Tensor):
        self.attn_mask = self.attn_mask.to(dtype=x.dtype, device=x.device) if self.attn_mask is not None else None
        return self.attn(x, x, x, need_weights=False, attn_mask=self.attn_mask)[0]

    def forward(self, x: torch.Tensor):
        x = x + self.attention(self.ln_1(x))
        x = x + self.mlp(self.ln_2(x))
        return x


class Transformer(nn.Module):
    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None,
                 checkpoint_num: int = 0):
        super().__init__()
        self.width = width
        self.layers = layers
        self.resblocks = nn.Sequential(*[ResidualAttentionBlock(width, heads, attn_mask) for _ in range(layers)])

        self.checkpoint_num = checkpoint_num

    def forward(self, x: torch.Tensor):
        if self.checkpoint_num > 0:
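            # Gradient checkpointing: checkpoint_sequential splits the residual blocks into
            # `segments` chunks and recomputes intermediate activations during the backward
            # pass, trading extra compute for lower activation memory.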
            segments = min(self.checkpoint_num, len(self.resblocks))
            return checkpoint.checkpoint_sequential(self.resblocks, segments, x)
        else:
            return self.resblocks(x)


class CLIP_TEXT(nn.Module):
    def __init__(
        self,
        embed_dim: int,
        context_length: int,
        vocab_size: int,
        transformer_width: int,
        transformer_heads: int,
        transformer_layers: int,
        checkpoint_num: int,
    ):
        super().__init__()

        self.context_length = context_length
        self._tokenizer = _Tokenizer()

        self.transformer = Transformer(
            width=transformer_width,
            layers=transformer_layers,
            heads=transformer_heads,
            attn_mask=self.build_attention_mask(),
            checkpoint_num=checkpoint_num,
        )

        self.vocab_size = vocab_size
        self.token_embedding = nn.Embedding(vocab_size, transformer_width)
        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))
        self.ln_final = LayerNorm(transformer_width)

        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))

    def no_weight_decay(self):
        return {'token_embedding', 'positional_embedding'}

    @functools.lru_cache(maxsize=None)
    def build_attention_mask(self):
        # lazily create the causal attention mask for the text transformer
        # pytorch uses an additive attention mask; fill with -inf
        mask = torch.empty(self.context_length, self.context_length)
        mask.fill_(float("-inf"))
        mask.triu_(1)  # keep -inf strictly above the diagonal; zero out the diagonal and below
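        # Resulting pattern (shown here for context_length = 4): token i may attend to
        # positions <= i, while later positions are masked out with -inf:
        #   [[0., -inf, -inf, -inf],
        #    [0.,   0., -inf, -inf],
        #    [0.,   0.,   0., -inf],
        #    [0.,   0.,   0.,   0.]]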
        return mask

    def tokenize(self, texts, context_length=77, truncate=True):
        """
        Returns the tokenized representation of given input string(s)

        Parameters
        ----------
        texts : Union[str, List[str]]
            An input string or a list of input strings to tokenize
        context_length : int
            The context length to use; all CLIP models use 77 as the context length
        truncate : bool
            Whether to truncate the text in case its encoding is longer than the context length

        Returns
        -------
        A two-dimensional tensor containing the resulting tokens, shape = [number of input strings, context_length].
        We return LongTensor when torch version is <1.8.0, since older index_select requires indices to be long.
        """
        if isinstance(texts, str):
            texts = [texts]

        sot_token = self._tokenizer.encoder["<|startoftext|>"]
        eot_token = self._tokenizer.encoder["<|endoftext|>"]
        all_tokens = [[sot_token] + self._tokenizer.encode(text) + [eot_token] for text in texts]
        if packaging.version.parse(torch.__version__) < packaging.version.parse("1.8.0"):
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.long)
        else:
            result = torch.zeros(len(all_tokens), context_length, dtype=torch.int)

        for i, tokens in enumerate(all_tokens):
            if len(tokens) > context_length:
                if truncate:
                    tokens = tokens[:context_length]
                    tokens[-1] = eot_token
                else:
                    raise RuntimeError(f"Input {texts[i]} is too long for context length {context_length}")
            result[i, :len(tokens)] = torch.tensor(tokens)

        return result
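    # Illustrative example (not executed here): with the default context length, tokenizing
    # two prompts returns a [2, 77] tensor of token ids, each row starting with
    # <|startoftext|>, ending with <|endoftext|>, and zero-padded on the right:
    #   tokens = model.tokenize(["a photo of a cat", "a person riding a bike"])
    #   tokens.shape  # torch.Size([2, 77])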

    def forward(self, text):
        x = self.token_embedding(text)  # [batch_size, n_ctx, d_model]

        x = x + self.positional_embedding
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.ln_final(x)

        # x.shape = [batch_size, n_ctx, transformer.width]
        # take features from the eot embedding (eot_token is the highest number in each sequence)
        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection

        return x

def clip_text_b16(
    embed_dim=512,
    context_length=77,
    vocab_size=49408,
    transformer_width=512,
    transformer_heads=8,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    # raise NotImplementedError
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    # pretrained = _MODELS["ViT-B/16"]
    # logger.info(f"Load pretrained weights from {pretrained}")
    # state_dict = torch.load(pretrained, map_location='cpu')
    # model.load_state_dict(state_dict, strict=False)
    # return model.eval()
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-B/16"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
            if context_length < state_dict["positional_embedding"].size(0):
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )

        message = model.load_state_dict(state_dict, strict=False)
        print(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()


def clip_text_l14(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
    checkpoint_num=0,
    pretrained=True,
):
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers,
        checkpoint_num,
    )
    if pretrained:
        if isinstance(pretrained, str) and pretrained != "bert-base-uncased":
            pretrained = _MODELS[pretrained]
        else:
            pretrained = _MODELS["ViT-L/14"]
        logger.info(f"Load pretrained weights from {pretrained}")
        state_dict = torch.load(pretrained, map_location='cpu')
        if context_length != state_dict["positional_embedding"].size(0):
            # assert context_length < state_dict["positional_embedding"].size(0), "Cannot increase context length."
            print(f"Resize positional embedding from {state_dict['positional_embedding'].size(0)} to {context_length}")
            if context_length < state_dict["positional_embedding"].size(0):
                state_dict["positional_embedding"] = state_dict["positional_embedding"][:context_length]
            else:
                state_dict["positional_embedding"] = F.pad(
                    state_dict["positional_embedding"],
                    (0, 0, 0, context_length - state_dict["positional_embedding"].size(0)),
                    value=0,
                )

        message = model.load_state_dict(state_dict, strict=False)
        print(f"Load pretrained weights from {pretrained}: {message}")
    return model.eval()


def clip_text_l14_336(
    embed_dim=768,
    context_length=77,
    vocab_size=49408,
    transformer_width=768,
    transformer_heads=12,
    transformer_layers=12,
):
    raise NotImplementedError
    # NOTE: the code below is unreachable while the NotImplementedError above is raised;
    # "ViT-L/14_336" also has no entry in _MODELS, and CLIP_TEXT is called here without
    # the required checkpoint_num argument.
    model = CLIP_TEXT(
        embed_dim,
        context_length,
        vocab_size,
        transformer_width,
        transformer_heads,
        transformer_layers
    )
    pretrained = _MODELS["ViT-L/14_336"]
    logger.info(f"Load pretrained weights from {pretrained}")
    state_dict = torch.load(pretrained, map_location='cpu')
    model.load_state_dict(state_dict, strict=False)
    return model.eval()


def build_clip(config):
    model_cls = config.text_encoder.clip_teacher
    # clip_teacher is expected to name one of the factory functions in this module,
    # e.g. "clip_text_b16" or "clip_text_l14"; eval() resolves it to the callable.
    model = eval(model_cls)()
    return model
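

# Minimal usage sketch (illustrative only): build the text encoder directly and encode a
# batch of prompts. `pretrained=False` is used so the sketch does not depend on the
# extracted LAION text-tower checkpoints listed in _MODELS; the features therefore come
# from untrained weights and only the shapes are meaningful. Because of the relative
# import of SimpleTokenizer at the top of the file, run this from within the package.
if __name__ == "__main__":
    model = clip_text_b16(pretrained=False)
    tokens = model.tokenize(["a photo of a cat", "a person riding a bike"])  # [2, 77]
    with torch.no_grad():
        text_feats = model(tokens)  # [2, embed_dim] -> [2, 512] for the B/16 text tower
    print(tokens.shape, text_feats.shape)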