Skip to content

Commit 4569129

Browse files
Merge branch 'refactor' of https://github.com/ffhibnese/Model-Inversion-Attack-ToolBox into refactor
1 parent 5d21e32 commit 4569129

File tree

6 files changed

+703
-701
lines changed

6 files changed

+703
-701
lines changed
Lines changed: 67 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,67 +1,68 @@
1-
#!/usr/bin/env python3
2-
# coding=utf-8
3-
import torch
4-
from torch import nn
5-
import torchvision.models as models
6-
7-
# import vgg_m_face_bn_dag
8-
from .....models import *
9-
# import net_sphere
10-
import os
11-
12-
from collections import OrderedDict
13-
def get_model(arch_name, device, classifier_dir, dataset_name, use_dropout=False):
14-
15-
path = os.path.join(classifier_dir, 'target_eval', dataset_name)
16-
# path = None
17-
# state_dict = torch
18-
if arch_name == 'vgg16':
19-
path = os.path.join(path, 'VGG16_88.26.tar')
20-
model = (VGG16(1000))
21-
elif arch_name == 'ir152':
22-
path = os.path.join(path, 'IR152_91.16.tar')
23-
model = (IR152(1000))
24-
elif arch_name == 'facenet64':
25-
path = os.path.join(path, 'FaceNet64_88.50.tar')
26-
model = (FaceNet64(1000))
27-
elif arch_name == 'facenet':
28-
path = os.path.join(path, 'FaceNet_95.88.tar')
29-
model = (FaceNet(1000))
30-
elif arch_name =='resnet50_scratch_dag':
31-
path = os.path.join(path, 'resnet50_scratch_dag.pth')
32-
model = Resnet50_scratch_dag()
33-
elif arch_name == 'vgg_face_dag':
34-
path = os.path.join(path, 'vgg_face_dag.pth')
35-
model = Vgg_face_dag(use_dropout=use_dropout)
36-
elif arch_name == 'inception_resnetv1_vggface2':
37-
path = os.path.join(path, '20180402-114759-vggface2.pt')
38-
# model = InceptionResnetV1(classify=True, pretrained='vggface2', ckpt_path=path)
39-
model = InceptionResnetV1(classify=True, pretrained=None, ckpt_path=path, num_classes=8631)
40-
# elif arch_name == 'inception_resnetv1_casia':
41-
# model = InceptionResnetV1(classify=True, pretrained='casia-webface')
42-
else:
43-
raise RuntimeError('arch name error')
44-
# print(f'>>>>>>>>> arch name: {arch_name} path {path}')
45-
46-
if path is not None and os.path.isfile(path):
47-
48-
state_dict = torch.load(path, map_location=device)
49-
50-
if isinstance(state_dict, dict) and 'state_dict' in state_dict.keys():
51-
state_dict = state_dict['state_dict']
52-
53-
54-
# new_state_dict = OrderedDict()
55-
# for k, v in state_dict.items():
56-
# name = k
57-
# if k.startswith('module.'):
58-
# pl = nn.DataParallel(model)
59-
# pl.load_state_dict(state_dict)
60-
# torch.save({'state_dict': pl.module.state_dict()}, path)
61-
# state_dict = torch.load(path)['state_dict']
62-
# break
63-
model.load_state_dict(state_dict)
64-
else:
65-
print(path)
66-
raise RuntimeError('not checkpoint')
1+
2+
#!/usr/bin/env python3
3+
# coding=utf-8
4+
import torch
5+
from torch import nn
6+
import torchvision.models as models
7+
8+
# import vgg_m_face_bn_dag
9+
from .....models import *
10+
# import net_sphere
11+
import os
12+
13+
from collections import OrderedDict
14+
def get_model(arch_name, device, classifier_dir, dataset_name, use_dropout=False):
15+
16+
path = os.path.join(classifier_dir, 'target_eval', dataset_name)
17+
# path = None
18+
# state_dict = torch
19+
if arch_name == 'vgg16':
20+
path = os.path.join(path, 'VGG16_88.26.tar')
21+
model = (VGG16(1000))
22+
elif arch_name == 'ir152':
23+
path = os.path.join(path, 'IR152_91.16.tar')
24+
model = (IR152(1000))
25+
elif arch_name == 'facenet64':
26+
path = os.path.join(path, 'FaceNet64_88.50.tar')
27+
model = (FaceNet64(1000))
28+
elif arch_name == 'facenet':
29+
path = os.path.join(path, 'FaceNet_95.88.tar')
30+
model = (FaceNet(1000))
31+
elif arch_name =='resnet50_scratch_dag':
32+
path = os.path.join(path, 'resnet50_scratch_dag.pth')
33+
model = Resnet50_scratch_dag()
34+
elif arch_name == 'vgg_face_dag':
35+
path = os.path.join(path, 'vgg_face_dag.pth')
36+
model = Vgg_face_dag(use_dropout=use_dropout)
37+
elif arch_name == 'inception_resnetv1_vggface2':
38+
path = os.path.join(path, '20180402-114759-vggface2.pt')
39+
# model = InceptionResnetV1(classify=True, pretrained='vggface2', ckpt_path=path)
40+
model = InceptionResnetV1(classify=True, pretrained=None, ckpt_path=path, num_classes=8631)
41+
# elif arch_name == 'inception_resnetv1_casia':
42+
# model = InceptionResnetV1(classify=True, pretrained='casia-webface')
43+
else:
44+
raise RuntimeError('arch name error')
45+
# print(f'>>>>>>>>> arch name: {arch_name} path {path}')
46+
47+
if path is not None and os.path.isfile(path):
48+
49+
state_dict = torch.load(path, map_location=device)
50+
51+
if isinstance(state_dict, dict) and 'state_dict' in state_dict.keys():
52+
state_dict = state_dict['state_dict']
53+
54+
55+
# new_state_dict = OrderedDict()
56+
# for k, v in state_dict.items():
57+
# name = k
58+
# if k.startswith('module.'):
59+
# pl = nn.DataParallel(model)
60+
# pl.load_state_dict(state_dict)
61+
# torch.save({'state_dict': pl.module.state_dict()}, path)
62+
# state_dict = torch.load(path)['state_dict']
63+
# break
64+
model.load_state_dict(state_dict)
65+
else:
66+
print(path)
67+
raise RuntimeError('not checkpoint')
6768
return model.to(device)
Lines changed: 99 additions & 98 deletions
Original file line numberDiff line numberDiff line change
@@ -1,99 +1,100 @@
1-
2-
"""This file is modifed from synthesize.py. The goal is to return a generator which output an image in range [0., 1.]"""
3-
4-
import os
5-
import argparse
6-
import subprocess
7-
from tqdm import tqdm
8-
import numpy as np
9-
10-
import torch
11-
from torchvision.utils import save_image
12-
13-
from .models import MODEL_ZOO
14-
from .models import build_generator, build_discriminator
15-
from .utils.misc import bool_parser
16-
from .utils.visualizer import HtmlPageVisualizer
17-
18-
def postprocess(images):
19-
"""change the range from [-1, 1] to [0., 1.]"""
20-
images = torch.clamp((images + 1.) / 2., 0., 1.)
21-
return images
22-
23-
def get_genforce(model_name, device, checkpoint_dir, use_discri=True, use_w_space=True, use_z_plus_space=False, repeat_w=True):
24-
25-
trunc_psi = 0.7
26-
trunc_layers = 8
27-
28-
if model_name not in MODEL_ZOO:
29-
raise RuntimeError(f'model name `{model_name}` is not in model zoo')
30-
model_config = MODEL_ZOO[model_name].copy()
31-
url = model_config.pop('url')
32-
33-
print(f'Building generator for model `{model_name}`')
34-
if model_name.startswith('stylegan'):
35-
generator = build_generator(**model_config, repeat_w=repeat_w)
36-
else:
37-
generator = build_generator(**model_config)
38-
synthesis_kwargs = dict(trunc_psi=trunc_psi,
39-
trunc_layers=trunc_layers)
40-
41-
# Build discriminator
42-
if use_discri:
43-
print(f'Building discriminator for model `{model_name}` ...')
44-
discriminator = build_discriminator(**model_config)
45-
else:
46-
discriminator = None
47-
48-
# load checkpoints
49-
os.makedirs(os.path.join(checkpoint_dir, 'genforce'), exist_ok=True)
50-
ckpt_path = os.path.join(checkpoint_dir, 'genforce', f'{model_name}.pth')
51-
52-
if not os.path.exists(ckpt_path):
53-
print(f'Download checkpoint {model_name} from {url} ...')
54-
subprocess.call(['wget', '--quiet', '-O', ckpt_path, url])
55-
56-
checkpoint = torch.load(ckpt_path)
57-
58-
if 'generator_smooth' in checkpoint:
59-
generator.load_state_dict(checkpoint['generator_smooth'])
60-
else:
61-
generator.load_state_dict(checkpoint['generator'])
62-
generator = generator.to(device)
63-
generator.eval()
64-
if use_discri:
65-
discriminator.load_state_dict(checkpoint['discriminator'])
66-
discriminator = discriminator.to(device)
67-
discriminator.eval()
68-
print('Finish loading checkpoint.')
69-
70-
def fake_generator(code):
71-
# Sample and synthesize.
72-
# print(f'Synthesizing {args.num} samples ...')
73-
# code = torch.randn(args.batch_size, generator.z_space_dim).cuda()
74-
if use_z_plus_space:
75-
code = generator.mapping(code)['w']
76-
code = code.view(-1, generator.num_layers, generator.w_space_dim)
77-
images = generator(code, **synthesis_kwargs, use_w_space=use_w_space)['image']
78-
images = postprocess(images)
79-
# save_image(images, os.path.join(work_dir, 'tmp.png'), nrow=5)
80-
# print(f'Finish synthesizing {args.num} samples.')
81-
return images
82-
83-
return Fake_G(generator, fake_generator), discriminator
84-
85-
class Fake_G:
86-
87-
def __init__(self, G, g_function):
88-
self.G = G
89-
self.g_function = g_function
90-
91-
def __call__(self, code):
92-
# print(f'code.shape {code.shape}')
93-
return self.g_function(code)
94-
95-
def mapping(self, code, label=None):
96-
return self.G.mapping(code, label=None)
97-
98-
def zero_grad(self):
1+
2+
"""This file is modifed from synthesize.py. The goal is to return a generator which output an image in range [0., 1.]"""
3+
4+
5+
import os
6+
import argparse
7+
import subprocess
8+
from tqdm import tqdm
9+
import numpy as np
10+
11+
import torch
12+
from torchvision.utils import save_image
13+
14+
from .models import MODEL_ZOO
15+
from .models import build_generator, build_discriminator
16+
from .utils.misc import bool_parser
17+
from .utils.visualizer import HtmlPageVisualizer
18+
19+
def postprocess(images):
20+
"""change the range from [-1, 1] to [0., 1.]"""
21+
images = torch.clamp((images + 1.) / 2., 0., 1.)
22+
return images
23+
24+
def get_genforce(model_name, device, checkpoint_dir, use_discri=True, use_w_space=True, use_z_plus_space=False, repeat_w=True):
25+
26+
trunc_psi = 0.7
27+
trunc_layers = 8
28+
29+
if model_name not in MODEL_ZOO:
30+
raise RuntimeError(f'model name `{model_name}` is not in model zoo')
31+
model_config = MODEL_ZOO[model_name].copy()
32+
url = model_config.pop('url')
33+
34+
print(f'Building generator for model `{model_name}`')
35+
if model_name.startswith('stylegan'):
36+
generator = build_generator(**model_config, repeat_w=repeat_w)
37+
else:
38+
generator = build_generator(**model_config)
39+
synthesis_kwargs = dict(trunc_psi=trunc_psi,
40+
trunc_layers=trunc_layers)
41+
42+
# Build discriminator
43+
if use_discri:
44+
print(f'Building discriminator for model `{model_name}` ...')
45+
discriminator = build_discriminator(**model_config)
46+
else:
47+
discriminator = None
48+
49+
# load checkpoints
50+
os.makedirs(os.path.join(checkpoint_dir, 'genforce'), exist_ok=True)
51+
ckpt_path = os.path.join(checkpoint_dir, 'genforce', f'{model_name}.pth')
52+
53+
if not os.path.exists(ckpt_path):
54+
print(f'Download checkpoint {model_name} from {url} ...')
55+
subprocess.call(['wget', '--quiet', '-O', ckpt_path, url])
56+
57+
checkpoint = torch.load(ckpt_path)
58+
59+
if 'generator_smooth' in checkpoint:
60+
generator.load_state_dict(checkpoint['generator_smooth'])
61+
else:
62+
generator.load_state_dict(checkpoint['generator'])
63+
generator = generator.to(device)
64+
generator.eval()
65+
if use_discri:
66+
discriminator.load_state_dict(checkpoint['discriminator'])
67+
discriminator = discriminator.to(device)
68+
discriminator.eval()
69+
print('Finish loading checkpoint.')
70+
71+
def fake_generator(code):
72+
# Sample and synthesize.
73+
# print(f'Synthesizing {args.num} samples ...')
74+
# code = torch.randn(args.batch_size, generator.z_space_dim).cuda()
75+
if use_z_plus_space:
76+
code = generator.mapping(code)['w']
77+
code = code.view(-1, generator.num_layers, generator.w_space_dim)
78+
images = generator(code, **synthesis_kwargs, use_w_space=use_w_space)['image']
79+
images = postprocess(images)
80+
# save_image(images, os.path.join(work_dir, 'tmp.png'), nrow=5)
81+
# print(f'Finish synthesizing {args.num} samples.')
82+
return images
83+
84+
return Fake_G(generator, fake_generator), discriminator
85+
86+
class Fake_G:
87+
88+
def __init__(self, G, g_function):
89+
self.G = G
90+
self.g_function = g_function
91+
92+
def __call__(self, code):
93+
# print(f'code.shape {code.shape}')
94+
return self.g_function(code)
95+
96+
def mapping(self, code, label=None):
97+
return self.G.mapping(code, label=None)
98+
99+
def zero_grad(self):
99100
self.G.zero_grad()
Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
1-
2-
3-
import os
4-
5-
def create_folder(folder):
6-
if os.path.exists(folder):
7-
assert os.path.isdir(folder), 'it exists but is not a folder'
8-
else:
1+
2+
import os
3+
4+
def create_folder(folder):
5+
if os.path.exists(folder):
6+
assert os.path.isdir(folder), 'it exists but is not a folder'
7+
else:
98
os.makedirs(folder)
Lines changed: 22 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,23 @@
1-
import torch
2-
from torch import nn
3-
4-
def verify_acc(inputs, labels, model, arch_name):
5-
6-
device = inputs.device
7-
8-
acc = 0
9-
10-
with torch.no_grad():
11-
12-
pred = model(inputs)
13-
if arch_name == 'sphere20a':
14-
pred = pred.result
15-
else:
16-
pred = pred.result
17-
18-
confidence = nn.functional.softmax(pred, dim=-1)
19-
pred_label = torch.argmax(pred, dim=-1)
20-
acc = (pred_label == labels).sum() / len(labels)
21-
1+
import torch
2+
from torch import nn
3+
4+
5+
def verify_acc(inputs, labels, model, arch_name):
6+
7+
device = inputs.device
8+
9+
acc = 0
10+
11+
with torch.no_grad():
12+
13+
pred = model(inputs)
14+
if arch_name == 'sphere20a':
15+
pred = pred.result
16+
else:
17+
pred = pred.result
18+
19+
confidence = nn.functional.softmax(pred, dim=-1)
20+
pred_label = torch.argmax(pred, dim=-1)
21+
acc = (pred_label == labels).sum() / len(labels)
22+
2223
return acc

0 commit comments

Comments
 (0)