@@ -1,3 +1,4 @@
+import torch
 from torch import nn
 from transformers import AutoModel, AutoProcessor
 from functools import partial
@@ -11,14 +12,19 @@ def encode_text(self, text):
         return self.model.get_text_features(**text)
 
     def encode_image(self, image):
-        return self.model.get_image_features(image["pixel_values"].squeeze(1))
+        # we get an extra dimension, likely from collation in the dataloader
+        image = {key: value.squeeze(1) for key, value in image.items()}
+        return self.model.get_image_features(**image)
 
 def load_transformers_clip(model_name, pretrained, cache_dir, device):
     ckpt = f"{model_name}/{pretrained}"
     model = AutoModel.from_pretrained(ckpt, cache_dir=cache_dir, device_map=device)
     model = TransformerWrapper(model)
+
     processor = AutoProcessor.from_pretrained(ckpt)
-
-    transforms = partial(processor.image_processor, return_tensors="pt")
-    tokenizer = partial(processor.tokenizer, return_tensors="pt", padding="max_length")
+    transforms = partial(processor.image_processor.preprocess, return_tensors="pt")
+    tokenizer = partial(
+        processor.tokenizer, return_tensors="pt", padding="max_length",
+        max_length=64,  # very specific to SG2
+    )
     return model, transforms, tokenizer