-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain_robust.py
More file actions
59 lines (49 loc) · 2.22 KB
/
main_robust.py
File metadata and controls
59 lines (49 loc) · 2.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import os
import pandas as pd
import torch
from pathlib import Path
import argparse
import json
# Local imports
from dataset_loaders import get_dataset
from model_builders import build_openclip_text
from dataset_loaders import get_dataset
from main_ood import get_args, get_ood_data, apply_all, create_out_dir, get_corpus
def get_indist_test(arch, indist_names=None):
    """Load the ``train=False`` split of each in-distribution dataset.

    Args:
        arch: Architecture identifier, forwarded to ``get_dataset`` as
            ``precompute_arch``.
        indist_names: Names of the datasets to load, in order. When omitted,
            the six default names listed below are used.

    Returns:
        A ``(datasets, indist_names)`` tuple: the objects returned by
        ``get_dataset`` in the same order as the names, and the names
        themselves.
    """
    if indist_names is None:
        # Build the default on every call. A list literal in the signature is
        # a single shared object, and this function hands that list back to
        # the caller, so a mutation there would silently change the default
        # for every later call.
        indist_names = ["IN_V2", "IN_A", "IN_R", "IN_C", "sketch", "IN1K"]
    datasets = [
        get_dataset(name, train=False, precompute_arch=arch)
        for name in indist_names
    ]
    return datasets, indist_names
def main_robust():
    """Evaluate OOD detection against every in-distribution test set.

    Parses the command-line arguments via ``get_args``, overrides the list of
    OOD datasets with a fixed set, loads the in-distribution training data,
    the OOD datasets, the in-distribution test sets and the corpus/text
    embeddings, then calls ``apply_all`` once per in-distribution test set.

    For each test set and each metric ("AUROC", "FPR95") with at least one
    result frame, the frames are concatenated row-wise, an ``average`` column
    holding the row-wise mean of all columns is appended, and the table is
    written to ``args.out_dir / "{metric}_{args.methodname}.csv"``.
    """
    # Passed through to apply_all unchanged; this script always uses None.
    proj = None
    args = get_args()
    # The OOD dataset list is fixed here and replaces whatever get_args() set.
    args.ood_names = ['NINCO', 'IN_O', "openimage_o", 'inat', 'IN21OOD', "texturev2"]
    print(f"\n\n Running robustness analysis for {args.ood_names} \n\n")

    # In-distribution training set.
    indist_train = get_dataset(args.dataset, train=True, precompute_arch=args.arch)
    # OOD test sets.
    ood_datasets = get_ood_data(args.arch, ood_names=args.ood_names)
    # In-distribution test sets.
    indist_tests, indist_names = get_indist_test(args.arch)
    args.indist_names = indist_names

    # Corpus and text embeddings are loaded once and shared by every run below.
    corpus_embeds = get_corpus(args.arch, args.corpus_name, args.prompt)
    text_embeds_in1k = indist_train.get_corpus(prompt=args.prompt)

    for indist_test, indist_name in zip(indist_tests, indist_names):
        print(f"\n\n Running analysis for {indist_name} \n\n")
        # The output directory is re-derived per test set, so each run's CSVs
        # land wherever create_out_dir points for that indist_name.
        args.out_dir = create_out_dir(args, indist_name=indist_name)
        dfs_auroc, dfs_fpr = apply_all(
            args, indist_train, indist_test, ood_datasets,
            corpus_embeds, text_embeds_in1k, proj, display_results=False,
        )
        for metric, df_list in (("AUROC", dfs_auroc), ("FPR95", dfs_fpr)):
            if len(df_list) == 0:
                # Nothing was produced for this metric; pd.concat would raise
                # on an empty sequence.
                continue
            df = pd.concat(df_list, axis=0)
            # Extra column averaging all column values of each row.
            df["average"] = df.mean(axis=1)
            df.to_csv(args.out_dir / f"{metric}_{args.methodname}.csv")
# Run the robustness analysis only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main_robust()