Commit 0d08a19

adding comments

1 parent 69efdcc

File tree: 1 file changed (+10, -4 lines)


clip_benchmark/models/transformers_clip.py

Lines changed: 10 additions & 4 deletions
@@ -1,3 +1,4 @@
+import torch
 from torch import nn
 from transformers import AutoModel, AutoProcessor
 from functools import partial
@@ -11,14 +12,19 @@ def encode_text(self, text):
         return self.model.get_text_features(**text)

     def encode_image(self, image):
-        return self.model.get_image_features(image["pixel_values"].squeeze(1))
+        # we get an extra dimension, possibly due to the collation in the dataloader
+        image = {key: value.squeeze(1) for key, value in image.items()}
+        return self.model.get_image_features(**image)

 def load_transformers_clip(model_name, pretrained, cache_dir, device):
     ckpt = f"{model_name}/{pretrained}"
     model = AutoModel.from_pretrained(ckpt, cache_dir=cache_dir, device_map=device)
     model = TransformerWrapper(model)
+
     processor = AutoProcessor.from_pretrained(ckpt)
-
-    transforms = partial(processor.image_processor, return_tensors="pt")
-    tokenizer = partial(processor.tokenizer, return_tensors="pt", padding="max_length")
+    transforms = partial(processor.image_processor.preprocess, return_tensors="pt")
+    tokenizer = partial(
+        processor.tokenizer, return_tensors="pt", padding="max_length",
+        max_length=64,  # very specific to SG2
+    )
     return model, transforms, tokenizer
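
For context on the new squeeze(1): a Hugging Face image processor returns pixel_values with a leading batch axis of 1 per sample, so the default DataLoader collation stacks samples into [batch, 1, 3, H, W] rather than [batch, 3, H, W]. A minimal sketch of the shape bookkeeping (the 224x224 size and batch of 8 are illustrative, not taken from the commit):

import torch

# Hypothetical per-sample output of an HF image processor: a size-1 batch axis.
per_sample = torch.zeros(1, 3, 224, 224)

# The default collate_fn stacks per-sample tensors along a new batch axis,
# producing [batch, 1, 3, H, W] instead of the [batch, 3, H, W] the model wants.
batch = torch.stack([per_sample for _ in range(8)])
print(batch.shape)             # torch.Size([8, 1, 3, 224, 224])

# squeeze(1) drops the stale size-1 axis, which is what encode_image now
# applies to every tensor in the processor output dict.
print(batch.squeeze(1).shape)  # torch.Size([8, 3, 224, 224])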

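A hedged usage sketch of the updated loader. The "openai" / "clip-vit-base-patch32" checkpoint pair is an assumed example (the commit names none; f"{model_name}/{pretrained}" becomes the HF repo id), and device_map=device inside load_transformers_clip requires accelerate to be installed:

from PIL import Image
import numpy as np
from clip_benchmark.models.transformers_clip import load_transformers_clip

# Assumed checkpoint, for illustration only.
model, transforms, tokenizer = load_transformers_clip(
    model_name="openai",
    pretrained="clip-vit-base-patch32",
    cache_dir=None,
    device="cpu",
)

tokens = tokenizer(["a photo of a dog"])   # padded to max_length=64
text_feats = model.encode_text(tokens)     # [batch, projection_dim]

img = Image.fromarray(np.zeros((224, 224, 3), dtype=np.uint8))
inputs = transforms(images=[img])          # BatchFeature with pixel_values
image_feats = model.encode_image(inputs)   # [batch, projection_dim]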