-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgen_text_embeds.py
More file actions
87 lines (75 loc) · 3.97 KB
/
gen_text_embeds.py
File metadata and controls
87 lines (75 loc) · 3.97 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
from pathlib import Path
import torch
import json
import os
from tqdm import tqdm
import pandas as pd
import numpy as np
import argparse
from dataset_loaders import get_dataset, get_corpus_path, get_corpus_names
from model_builders import clip_label_embeddings, build_openclip_text
from model_builders.prompts import get_prompts,get_avail_prompts
from ood import zero_shot_clip
def generate_text_emb_all(dataset_names, context_prompts, arch="clip:0",
                          device="cuda", logit_scale=True, merge_lemmas=False,
                          step_size=256):
    """Embed all class names with a CLIP text encoder, in batches.

    Parameters
    ----------
    dataset_names : sequence of str
        Class names to embed.
    context_prompts : list of str
        Prompt templates applied to each class name.
    arch : str
        CLIP architecture spec forwarded to ``clip_label_embeddings``.
    device : str
        Torch device the text encoder runs on.
    logit_scale : bool
        Whether the model's logit scale is applied to the embeddings.
    merge_lemmas : bool
        Forwarded as ``multi_prompt``; joins multiple lemmas per class
        with ``", "``.
    step_size : int
        Number of class names encoded per batch.

    Returns
    -------
    torch.Tensor
        Embeddings concatenated along dim 0, on CPU, one row per name.
    """
    all_text_emb = []
    for start in tqdm(range(0, len(dataset_names), step_size)):
        # Python slicing clamps at the sequence end, so the final
        # (possibly short) batch needs no special-casing.
        batch = dataset_names[start:start + step_size]
        batch_embeds, _ = clip_label_embeddings(
            arch, batch, context_prompts, device=device,
            logit_scale=logit_scale, multi_prompt=merge_lemmas,
            separator=", ")
        assert len(batch) == len(batch_embeds), \
            f"len(ds)={len(batch)} len(batch_embeds)={len(batch_embeds)}"
        # Move each batch to CPU immediately so GPU memory is not
        # accumulated across the whole corpus.
        all_text_emb.append(batch_embeds.cpu())
    # torch.cat uses `dim`, not the NumPy alias `axis`.
    return torch.cat(all_text_emb, dim=0)
def get_args():
    """Parse CLI arguments for zero-shot text-embedding generation.

    Returns
    -------
    argparse.Namespace
        Parsed arguments, augmented with ``context_prompts`` (list of
        str) resolved from the selected ``--prompt`` type.
    """
    parser = argparse.ArgumentParser(description="Zero shot evaluation with CLIP")
    parser.add_argument("--arch", type=str, help="CLIP model")
    parser.add_argument("--prompt", type=str, default="simple",
                        choices=get_avail_prompts(), help="Prompt type")
    # Fixed help strings: the old ones were a copy-paste ("Path to csv
    # names") that did not describe either option.
    parser.add_argument("--path_corpus", type=str, default=None,
                        help="Output path for the text embeddings; derived "
                             "from arch/corpus/prompt when omitted")
    parser.add_argument("--corpus_name", type=str, default="IN1K",
                        help="Name of the class-name corpus (e.g. IN1K)")
    parser.add_argument("--overwrite", type=int, default=0,
                        help="If nonzero, regenerate embeddings even when "
                             "the output file already exists")
    args = parser.parse_args()
    args.context_prompts = get_prompts(args.prompt)
    return args
if __name__ == "__main__":
    args = get_args()
    if args.path_corpus is None:
        args.path_corpus = get_corpus_path(args.arch, args.corpus_name, args.prompt)
    # BUG FIX: a CLI-supplied --path_corpus arrives as str (argparse
    # type=str), and `.parent` below would raise AttributeError.
    # Path() is idempotent, so this is safe either way.
    args.path_corpus = Path(args.path_corpus)
    corpus_class_names = get_corpus_names(args.corpus_name)
    print(f'Corpus {args.corpus_name} loaded with {len(corpus_class_names)} class names')
    basedir = args.path_corpus.parent
    if os.path.exists(args.path_corpus) and args.overwrite == 0:
        print(f"Embeddings for {args.arch} and prompt {args.prompt} already exist")
        # Embeddings are saved on CPU (see generate_text_emb_all), so
        # map_location="cpu" is safe and lets this run without a GPU.
        text_embeds = torch.load(args.path_corpus, map_location="cpu")
    else:
        basedir.mkdir(parents=True, exist_ok=True)
        # Save args to a json file for reproducibility; Path values are
        # stringified so the namespace is JSON-serializable.
        with open(basedir / "args.json", "w") as f:
            args_dict = {key: str(value) if isinstance(value, Path) else value
                         for key, value in vars(args).items()}
            json.dump(args_dict, f, indent=4)
        print(f"Generate embeds with context prompts {args.prompt}...")
        text_embeds = generate_text_emb_all(corpus_class_names, args.context_prompts,
                                            arch=args.arch, device="cuda",
                                            logit_scale=True, merge_lemmas=False)
        torch.save(text_embeds, args.path_corpus)
        # Also save the context prompts (one per line) next to the embeddings.
        path_out = str(basedir / f"{args.arch}_context_prompts.txt")
        with open(path_out, "w") as f:
            f.write("\n".join(args.context_prompts))
        print("Text prompts saved to ", path_out)
    # Total zero-shot accuracy on ImageNet-1k using the corpus.
    if args.corpus_name == "IN1K":
        dataset = get_dataset(args.corpus_name, train=True, precompute_arch=args.arch, whiten=False)
        print(f"Zero-shot evaluation on {args.corpus_name}...")
        scores, preds = zero_shot_clip(text_embeds, dataset, step=512)
        print(f"Total accuracy: {np.mean(scores)*100}")
        # Persist the mean accuracy as a single-row CSV.
        df = pd.DataFrame({"acc": [np.mean(scores) * 100]})
        df.to_csv(basedir / "accuracy_zero_shot.csv")