-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFID_score.py
More file actions
58 lines (45 loc) · 1.88 KB
/
FID_score.py
File metadata and controls
58 lines (45 loc) · 1.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import functools
from os import getcwd

import numpy as np
import torch
from scipy.linalg import sqrtm
from torchvision import transforms, models
from torchvision.models.feature_extraction import create_feature_extractor
from torchvision.models.feature_extraction import get_graph_node_names
# Current working directory, captured once when this module is imported.
# NOTE(review): not referenced by calculate_fid or any other visible code in
# this file; possibly imported by other modules — confirm before removing.
cwd = getcwd()
@functools.lru_cache(maxsize=None)  # one entry per device; the key space is tiny
def _load_inception_extractor(device):
    """Return pretrained Inception-v3 truncated so it outputs only ``'flatten'``.

    Cached per ``device`` so repeated FID calls do not re-run
    ``torch.hub.load`` (which resolves the ``pytorch/vision`` GitHub repo)
    or rebuild the FX feature-extraction graph each time.

    Args:
        device: ``torch.device`` the extractor is moved to.

    Returns:
        A module whose forward returns a dict ``{'flatten': tensor}``.
    """
    model = torch.hub.load('pytorch/vision:v0.10.0', 'inception_v3', pretrained=True)
    model.eval()
    # Only the flattened activations feeding the final classifier layer are
    # used for FID; requesting just that node avoids keeping every
    # intermediate activation alive.
    extractor = create_feature_extractor(model, return_nodes=['flatten'])
    return extractor.to(device).eval()


def _extract_features(images, preprocess, extractor, device, batch_size):
    """Preprocess ``images`` and return their ``'flatten'`` features.

    Args:
        images: Batch of images; the first dimension indexes samples.
        preprocess: Callable applied to each (device-resident) batch.
        extractor: Callable returning a dict with a ``'flatten'`` entry.
        device: Device the batches are moved to before preprocessing.
        batch_size: Samples per forward pass, or ``None`` for a single pass.

    Returns:
        A float64 NumPy array of shape (num_samples, num_features).
    """
    batches = (images,) if batch_size is None else torch.split(images, batch_size)
    features = []
    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        for batch in batches:
            # Inputs must live on the extractor's device; leaving them on the
            # CPU while the model is on CUDA raises a device-mismatch error.
            batch = preprocess(batch.to(device))
            out = extractor(batch)['flatten']
            features.append(torch.flatten(out, 1).cpu())
    return torch.cat(features).numpy().astype(np.float64)


def _frechet_distance(feats1, feats2, eps=1e-6):
    """Fréchet distance between Gaussians fitted to two feature matrices.

    Computes ``||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrtm(S1 @ S2))`` where the
    means/covariances are taken over rows (samples).

    Args:
        feats1: Array of shape (n1, d).
        feats2: Array of shape (n2, d).
        eps: Diagonal offset used only if the matrix square root is not finite.

    Returns:
        The distance as a NumPy float.
    """
    mu1, sigma1 = feats1.mean(axis=0), np.atleast_2d(np.cov(feats1, rowvar=False))
    mu2, sigma2 = feats2.mean(axis=0), np.atleast_2d(np.cov(feats2, rowvar=False))
    # Sum of squared differences between the means.
    ssdiff = np.sum((mu1 - mu2) ** 2.0)
    # Square root of the product of the covariances.
    covmean = sqrtm(sigma1.dot(sigma2))
    if not np.isfinite(covmean).all():
        # Near-singular covariance product (e.g. fewer samples than feature
        # dimensions): regularise the diagonal and retry.
        offset = np.eye(sigma1.shape[0]) * eps
        covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))
    # sqrtm can return tiny imaginary parts from numerical error; drop them.
    if np.iscomplexobj(covmean):
        covmean = covmean.real
    return ssdiff + np.trace(sigma1 + sigma2 - 2.0 * covmean)


def calculate_fid(train, target, *, preprocess=None, feature_extractor=None,
                  device=None, batch_size=None):
    """Compute the Fréchet Inception Distance (FID) between two image sets.

    Args:
        train: Image tensor whose first dimension indexes samples (at least
            two). With the default ``preprocess`` it is presumably a float
            RGB batch (N, 3, H, W) in [0, 1], matching the ImageNet
            ``Normalize`` statistics — TODO confirm against callers.
        target: Second image set, same conventions as ``train``.
        preprocess: Optional callable applied to each batch before feature
            extraction. Defaults to Resize(299) -> CenterCrop(299) ->
            ImageNet Normalize.
        feature_extractor: Optional callable returning a dict with a
            ``'flatten'`` entry. Defaults to a cached pretrained Inception-v3
            truncated at its ``'flatten'`` node. A custom extractor is used
            as-is and must accept inputs on ``device``.
        device: Device to run on. Defaults to CUDA when available, else CPU.
        batch_size: Optional number of images per forward pass, to bound
            memory on large sets. ``None`` processes each set in one pass.

    Returns:
        The FID as a NumPy float; lower means the feature distributions are
        closer.

    Raises:
        ValueError: If either set has fewer than two images (the covariance
            is undefined).
    """
    if len(train) < 2 or len(target) < 2:
        raise ValueError("FID needs at least two images per set to estimate a covariance")

    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    device = torch.device(device)

    if preprocess is None:
        preprocess = transforms.Compose([
            transforms.Resize(299),
            transforms.CenterCrop(299),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
    if feature_extractor is None:
        feature_extractor = _load_inception_extractor(device)

    train_feats = _extract_features(train, preprocess, feature_extractor, device, batch_size)
    target_feats = _extract_features(target, preprocess, feature_extractor, device, batch_size)
    return _frechet_distance(train_feats, target_feats)