import argparse
import io
import json
import math
import os

import decord
import mmengine
import numpy as np
import torch
import tqdm
from PIL import Image
from transformers import AutoModel, AutoTokenizer, CLIPImageProcessor


def recall_at_k(scores, positive_pairs, k):
    """
    Compute the recall at k for each text query.

    :param scores: compatibility scores between text and image embeddings (nb_texts, nb_images)
    :param positive_pairs: boolean matrix of positive pairs (nb_texts, nb_images)
    :param k: number of images to retrieve per text
    :return: recall at k for each text, shape (nb_texts,)
    """
    nb_texts, nb_images = scores.shape
    # for each text, take the indices of the k images with the highest scores
    topk_indices = torch.topk(scores, k, dim=1)[1]
    # number of positives for each text
    nb_positive = positive_pairs.sum(dim=1)
    # one-hot encode the top-k indices: (nb_texts, k, nb_images)
    topk_indices_onehot = torch.nn.functional.one_hot(topk_indices, num_classes=nb_images)
    # a true positive is a positive pair that appears among the top-k
    positive_pairs_reshaped = positive_pairs.view(nb_texts, 1, nb_images)
    nb_true_positive = (topk_indices_onehot * positive_pairs_reshaped).sum(dim=(1, 2))
    # recall at k = true positives / positives
    recall_at_k = nb_true_positive / nb_positive
    return recall_at_k

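
# A minimal sanity check for recall_at_k (not part of the original evaluation flow; the toy
# tensors below are illustrative assumptions). With an identity positive-pair matrix and
# scores that rank each matching image first, recall@1 should be 1.0 for every text.
def _sanity_check_recall_at_k():
    scores = torch.tensor([[0.9, 0.1],
                           [0.2, 0.8]])
    positive_pairs = torch.eye(2, dtype=torch.bool)
    assert torch.all(recall_at_k(scores, positive_pairs, k=1) == 1.0)
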

def batchify(func, X, Y, batch_size, device, *args, **kwargs):
    """Apply func to (X, Y) in batches on the given device and concatenate the results on CPU."""
    results = []
    for start in range(0, len(X), batch_size):
        end = start + batch_size
        x = X[start:end].to(device)
        y = Y[start:end].to(device)
        result = func(x, y, *args, **kwargs).cpu()
        results.append(result)
    return torch.cat(results)


def validate_msrvtt(model, tokenizer, image_processor, root, metadata,
                    num_frames=1, prefix="summarize:", mode="InternVL-G", recall_k_list=[1, 5, 10],
                    use_dsl=True, eval_batch_size=32):
    with open(metadata) as f:
        metadata = json.load(f)

    video_features = []
    text_features = []

    # compute text features
    print("Computing text features", flush=True)
    for data in tqdm.tqdm(metadata):
        caption = prefix + data["caption"]
        input_ids = tokenizer(caption, return_tensors='pt', max_length=80,
                              truncation=True, padding='max_length').input_ids.cuda()
        with torch.no_grad():
            feat = model.encode_text(input_ids)
        text_features.append(feat.cpu())
    text_features = torch.cat(text_features)

    # compute video features
    print("Computing video features", flush=True)
    for data in tqdm.tqdm(metadata):
        video_id = data["video"]
        video_path = os.path.join(root, video_id)
        video_data = mmengine.get(video_path)
        video_data = io.BytesIO(video_data)
        video_reader = decord.VideoReader(video_data)

        # uniformly sample frames
        interval = math.ceil(len(video_reader) / num_frames)
        frames_id = np.arange(0, len(video_reader), interval) + interval // 2
        assert len(frames_id) == num_frames and frames_id[-1] < len(video_reader)
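        # For example (illustrative numbers, not from the original code): a 100-frame video with
        # num_frames=4 gives interval = ceil(100 / 4) = 25 and
        # frames_id = [0, 25, 50, 75] + 12 = [12, 37, 62, 87], i.e. roughly the center of each chunk.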

        frames = video_reader.get_batch(frames_id).asnumpy()

        pixel_values = image_processor(images=frames, return_tensors='pt').pixel_values
        with torch.no_grad():
            pixel_values = pixel_values.to(torch.bfloat16).cuda()
            feat = model.encode_image(pixel_values, mode=mode)
            feat = feat.mean(dim=0, keepdim=True)
        video_features.append(feat.cpu())
    video_features = torch.cat(video_features)

    print("Computing metrics", flush=True)
    texts_emb = text_features / text_features.norm(dim=-1, keepdim=True)
    images_emb = video_features / video_features.norm(dim=-1, keepdim=True)

    # get the score for each text-image pair
    scores = texts_emb @ images_emb.t()

    # construct the positive-pair matrix, which tells whether each text-image pair is a positive or not
    positive_pairs = torch.zeros_like(scores, dtype=torch.bool)
    positive_pairs[torch.arange(len(scores)), torch.arange(len(scores))] = True

    scores_T = scores.T
    positive_pairs_T = positive_pairs.T

    if use_dsl:
        # DSL (dual softmax) re-ranking: re-weight each score by its softmax taken along the query axis
        scores = scores * scores.softmax(dim=0)
        scores_T = scores_T * scores_T.softmax(dim=0)

    metrics = {}
    for recall_k in recall_k_list:
        # Note that recall_at_k computes the **actual** recall, i.e. nb_true_positive / nb_positive, where the
        # number of true positives, e.g. for text retrieval, is, for each image, the number of retrieved texts
        # matching that image among the top-k. The number of positives is the total number of texts matching the
        # image in the dataset; since there can be several captions per image, that number can be greater than 1
        # for text retrieval.
        # However, recall@k as reported in CLIP-like papers is a bit different: for each image it is either 1 or 0,
        # 1 if at least one matching text appears among the top-k. We can derive it from the actual recall by
        # checking whether there is at least one true positive, i.e. whether the recall is greater than 0.
        # Once we have this 0/1 recall for each image (or text), we average it over the dataset.
        metrics[f't2v_retrieval_recall@{recall_k}'] = (
            batchify(recall_at_k, scores, positive_pairs, eval_batch_size, scores.device,
                     k=recall_k) > 0).float().mean().item()
        metrics[f'v2t_retrieval_recall@{recall_k}'] = (
            batchify(recall_at_k, scores_T, positive_pairs_T, eval_batch_size, scores.device,
                     k=recall_k) > 0).float().mean().item()

    print(metrics)
    return metrics
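

# A minimal illustration of the CLIP-style recall described in the comment above (the tensors are
# illustrative assumptions, not data from the script). Image 0 has two matching texts but only one of
# them can be retrieved at k=1, so its actual recall is 0.5, yet the CLIP-style recall obtained by
# thresholding with "> 0" is 1.0 for both images.
def _sanity_check_clip_style_recall():
    scores_T = torch.tensor([[0.9, 0.6, 0.3],
                             [0.1, 0.4, 0.7]])
    positive_pairs_T = torch.tensor([[True, True, False],
                                     [False, False, True]])
    actual = recall_at_k(scores_T, positive_pairs_T, k=1)  # tensor([0.5, 1.0])
    clip_style = (actual > 0).float().mean().item()        # 1.0
    assert clip_style == 1.0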


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='validate MSR-VTT', add_help=False)
    parser.add_argument('--video-root', type=str)
    parser.add_argument('--metadata', type=str)
    parser.add_argument('--mode', type=str, default="InternVL-C", choices=["InternVL-C", "InternVL-G"])
    parser.add_argument('--num-frames', type=int, default=1)
    args = parser.parse_args()

    model = AutoModel.from_pretrained(
        'OpenGVLab/InternVL-14B-224px',
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        trust_remote_code=True).cuda().eval()

    image_processor = CLIPImageProcessor.from_pretrained('OpenGVLab/InternVL-14B-224px')

    tokenizer = AutoTokenizer.from_pretrained(
        'OpenGVLab/InternVL-14B-224px', use_fast=False, add_eos_token=True)
    tokenizer.pad_token_id = 0  # set pad_token_id to 0

    metrics = validate_msrvtt(model, tokenizer, image_processor,
                              root=args.video_root,
                              metadata=args.metadata,
                              mode=args.mode,
                              num_frames=args.num_frames)
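
# Example invocation (the file name and paths below are placeholders, not taken from the original script):
#   python eval_msrvtt_retrieval.py --video-root /path/to/MSRVTT/videos \
#       --metadata /path/to/msrvtt_test.json --mode InternVL-C --num-frames 1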