Commit 6e937da

czczup and cg1177 authored
Add MSRVTT test script (#123)
Co-authored-by: Guo Chen <[email protected]>
1 parent 30bf562 commit 6e937da

File tree

1 file changed: +161 -0 lines changed

video_retrieval/test_msrvtt.py

Lines changed: 161 additions & 0 deletions
import torch
from PIL import Image
from transformers import AutoModel, CLIPImageProcessor
from transformers import AutoTokenizer

import math
import numpy as np
import io
import os
import json
import mmengine
import decord
import tqdm
import argparse


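# Per-text recall@k over a (texts x images) score matrix; validate_msrvtt below
# converts this into the 0/1 recall@k reported in CLIP-like papers.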
def recall_at_k(scores, positive_pairs, k):
    """
    Compute the recall at k for each sample
    :param scores: compatibility score between text and image embeddings (nb texts, nb images)
    :param k: number of images to consider per text, for retrieval
    :param positive_pairs: boolean matrix of positive pairs (nb texts, nb images)
    :return: recall at k for each text (the caller averages over the dataset)
    """
    nb_texts, nb_images = scores.shape
    # for each text, sort according to image scores in decreasing order
    topk_indices = torch.topk(scores, k, dim=1)[1]
    # compute number of positives for each text
    nb_positive = positive_pairs.sum(dim=1)
    # nb_texts, k, nb_images
    topk_indices_onehot = torch.nn.functional.one_hot(topk_indices, num_classes=nb_images)
    # compute number of true positives
    positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
    # a true positive means a positive among the topk
    nb_true_positive = (topk_indices_onehot * positive_pairs_reshaped).sum(dim=(1, 2))
    # compute recall at k
    recall_at_k = nb_true_positive / nb_positive
    return recall_at_k


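# Apply `func` to aligned row-chunks of X and Y in batches of `batch_size` on
# `device`, then concatenate the per-chunk results on CPU.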
def batchify(func, X, Y, batch_size, device, *args, **kwargs):
    results = []
    for start in range(0, len(X), batch_size):
        end = start + batch_size
        x = X[start:end].to(device)
        y = Y[start:end].to(device)
        result = func(x, y, *args, **kwargs).cpu()
        results.append(result)
    return torch.cat(results)

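# MSR-VTT text<->video retrieval: encode every caption and every video, L2-normalize
# the embeddings, score all pairs with a dot product, and report recall@k in both
# directions (t2v and v2t).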
def validate_msrvtt(model, tokenizer, image_processor, root, metadata,
                    num_frames=1, prefix="summarize:", mode="InternVL-G", recall_k_list=[1, 5, 10],
                    use_dsl=True, eval_batch_size=32):
    metadata = json.load(open(metadata))

    video_features = []
    text_features = []

    # compute text features
    print("Computing text features", flush=True)
    for data in tqdm.tqdm(metadata):
        caption = prefix + data["caption"]
        input_ids = tokenizer(caption, return_tensors='pt', max_length=80,
                              truncation=True, padding='max_length').input_ids.cuda()
        with torch.no_grad():
            feat = model.encode_text(input_ids)
            text_features.append(feat.cpu())
    text_features = torch.cat(text_features)

    # compute video features
    print("Computing video features", flush=True)
    for data in tqdm.tqdm(metadata):
        video_id = data["video"]
        video_path = os.path.join(root, video_id)
        video_data = mmengine.get(video_path)
        video_data = io.BytesIO(video_data)
        video_reader = decord.VideoReader(video_data)

        # uniformly sample frames
        interval = math.ceil(len(video_reader) / num_frames)
        frames_id = np.arange(0, len(video_reader), interval) + interval // 2
        assert len(frames_id) == num_frames and frames_id[-1] < len(video_reader)

        frames = video_reader.get_batch(frames_id).asnumpy()

        pixel_values = image_processor(images=frames, return_tensors='pt').pixel_values
        with torch.no_grad():
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            feat = model.encode_image(pixel_values, mode=mode)
            feat = feat.mean(dim=0, keepdim=True)
            video_features.append(feat.cpu())
    video_features = torch.cat(video_features)

    print("Computing metrics", flush=True)
    texts_emb = text_features / text_features.norm(dim=-1, keepdim=True)
    images_emb = video_features / video_features.norm(dim=-1, keepdim=True)

    # get the score for each text and image pair
    scores = texts_emb @ images_emb.t()

    # construct the positive pair matrix, which tells whether each text-image pair is a positive or not
    positive_pairs = torch.zeros_like(scores, dtype=bool)
    positive_pairs[torch.arange(len(scores)), torch.arange(len(scores))] = True

    scores_T = scores.T
    positive_pairs_T = positive_pairs.T

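    # Dual softmax (DSL) re-ranking: multiply each score matrix by its softmax over
    # dim 0 (the opposite retrieval direction) before computing recall.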
    if use_dsl:
        scores = scores * scores.softmax(dim=0)
        scores_T = scores_T * scores_T.softmax(dim=0)

    metrics = {}
    for recall_k in recall_k_list:
        # Note that recall_at_k computes **actual** recall, i.e. nb_true_positive / nb_positive, where the number
        # of true positives, e.g. for text retrieval, is, for each image, the number of retrieved texts matching that image among the top-k.
        # Also, the number of positives is the total number of texts matching the image in the dataset; since we have a set of captions
        # for each image, that number will be greater than 1 for text retrieval.
        # However, image/text retrieval recall@k, the way it is done in CLIP-like papers, is a bit different:
        # recall@k in CLIP-like papers is, for each image, either 1 or 0. It is 1 if at least one text matches the image among the top-k.
        # So we can compute it from the actual recall by checking whether there is at least one true positive,
        # which is the case exactly when the recall is greater than 0. Once we compute the recall for each image (or text),
        # we average it over the dataset.
        metrics[f't2v_retrieval_recall@{recall_k}'] = (
            batchify(recall_at_k, scores, positive_pairs, eval_batch_size, scores.device,
                     k=recall_k) > 0).float().mean().item()
        metrics[f'v2t_retrieval_recall@{recall_k}'] = (
            batchify(recall_at_k, scores_T, positive_pairs_T, eval_batch_size, scores.device,
                     k=recall_k) > 0).float().mean().item()

    print(metrics)
    return metrics


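# Load the public OpenGVLab/InternVL-14B-224px checkpoint and evaluate it with the
# paths and options supplied on the command line.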
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='validate MSR-VTT', add_help=False)
    parser.add_argument('--video-root', type=str)
    parser.add_argument('--metadata', type=str)
    parser.add_argument('--mode', type=str, default="InternVL-C", choices=["InternVL-C", "InternVL-G"])
    parser.add_argument('--num-frames', type=int, default=1)
    args = parser.parse_args()

    model = AutoModel.from_pretrained(
        'OpenGVLab/InternVL-14B-224px',
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).cuda().eval()

    image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternVL-14B-224px')

    tokenizer = AutoTokenizer.from_pretrained(
        'OpenGVLab/InternVL-14B-224px', use_fast=False, add_eos_token=True)
    tokenizer.pad_token_id = 0  # set pad_token_id to 0

    metrics = validate_msrvtt(model, tokenizer, image_processor,
                              root=args.video_root,
                              metadata=args.metadata,
                              mode=args.mode,
                              num_frames=args.num_frames)
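Example invocation (a sketch; the paths are placeholders for a local MSR-VTT copy, and the metadata file is assumed to be a JSON list of {"video": ..., "caption": ...} entries, which is what the script reads):

python video_retrieval/test_msrvtt.py \
    --video-root /path/to/MSRVTT/videos \
    --metadata /path/to/msrvtt_test.json \
    --mode InternVL-C \
    --num-frames 1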
