-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathreconstruct.py
More file actions
112 lines (95 loc) · 4.19 KB
/
reconstruct.py
File metadata and controls
112 lines (95 loc) · 4.19 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
from core.models import GeminioResNet34
from core.dataset import CustomData
import breaching
import logging
import torch
import sys
import os
# Root logger: emit bare messages (no level/timestamp prefix) to stdout,
# at INFO severity and above.
logging.basicConfig(
    level=logging.INFO,
    format='%(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
)
logger = logging.getLogger()

# Semantic queries accepted by reconstruct_image(). Each query is mapped to a
# checkpoint file name by replacing spaces with "_" and dropping "?", so the
# strings below must stay exactly as written.
SUPPORTED_QUERIES = [
    "Any jewelry?",
    "Any human faces?",
    "Any males with a beard?",
    "Any guns?",
    "Any females riding a horse?",
]
def reconstruct_image(cfg, setup, query=None):
    """
    Reconstruct private training images using either baseline or query-based approach.

    Args:
        cfg: Configuration object containing model and training parameters.
            Reads ``cfg.case``, ``cfg.attack`` (including ``cfg.attack.save_dir``)
            and ``cfg.dryrun``.
        setup: Dictionary containing device and dtype settings; passed through
            unchanged to the ``breaching`` case/attack constructors.
        query: Optional semantic query string for targeted reconstruction.
            A falsy value (``None`` or ``""``) selects the baseline path, in
            which no query-specific weights are loaded.

    Returns:
        None (Saves reconstructed images to disk: ``a_truth.jpg`` and
        ``final_rec.jpg`` inside ``cfg.attack.save_dir``).

    Raises:
        ValueError: If ``query`` is given but is not in ``SUPPORTED_QUERIES``.
    """
    # Validate up front so an unsupported query fails before any model or
    # case construction work is done.
    if query and query not in SUPPORTED_QUERIES:
        raise ValueError(
            f"Query '{query}' is not supported. Supported queries are: {SUPPORTED_QUERIES}.\n"
            "We will release the training script after acceptance of our paper."
        )

    # Initialize model and components
    model = GeminioResNet34(num_classes=cfg.case.data.classes)
    user, server, model, _ = breaching.cases.construct_case(cfg.case, model, setup)

    # Load query-specific model if query is provided
    if query:
        checkpoint_name = query.replace(" ", "_").replace("?", "")
        model_path = f'./malicious_models/{checkpoint_name}.pt'
        # map_location='cpu' lets a checkpoint saved on a GPU be loaded on a
        # CPU-only host; load_state_dict then copies the values into the
        # model's existing parameters on whatever device they live on.
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints from a trusted source.
        model_state = torch.load(model_path, map_location='cpu')
        # Checkpoints may store keys with or without the 'clf.' prefix;
        # normalise to the prefixed form before loading.
        if not any(key.startswith('clf.') for key in model_state):
            model_state = {'clf.%s' % key: value for key, value in model_state.items()}
        # strict=False: keys that do not match the target module are ignored
        # rather than raising.
        model.model.load_state_dict(model_state, strict=False)

    # Setup attack components
    attacker_loss = torch.nn.CrossEntropyLoss()
    attacker = breaching.attacks.prepare_attack(server.model, attacker_loss, cfg.attack, setup)
    breaching.utils.overview(server, user, attacker)

    # Get server payload
    server_payload = server.distribute_payload()

    # Create the save directory, including missing parents such as
    # './results' for './results/<query>/'; an existing directory is fine.
    os.makedirs(cfg.attack.save_dir, exist_ok=True)

    # Load and process data
    cus_data = CustomData(
        data_dir='./assets/private_samples/',
        dataset_name='ImageNet',
        number_data_points=cfg.case.user.num_data_points,
    )

    # Compute updates and save ground truth
    shared_data, true_user_data = user.compute_local_updates(
        server_payload,
        custom_data=cus_data.process_data(),
    )
    # os.path.join keeps the path valid whether or not save_dir ends in '/'.
    truth_path = os.path.join(cfg.attack.save_dir, 'a_truth.jpg')
    cus_data.save_recover(true_user_data, save_pth=truth_path)

    # Perform reconstruction and save results
    reconstructed_user_data, stats = attacker.reconstruct(
        [server_payload],
        [shared_data],
        {},
        dryrun=cfg.dryrun,
        custom=cus_data,
    )
    recon_path = os.path.join(cfg.attack.save_dir, 'final_rec.jpg')
    cus_data.save_recover(reconstructed_user_data, true_user_data, recon_path)
if __name__ == '__main__':
    import argparse

    # Command line: exactly one of --baseline / --geminio-query is required.
    cli = argparse.ArgumentParser(description='Image reconstruction using Geminio')
    mode = cli.add_mutually_exclusive_group(required=True)
    mode.add_argument('--baseline', action='store_true', help='Run baseline reconstruction')
    mode.add_argument('--geminio-query', type=str, help='Query for Geminio reconstruction')
    options = cli.parse_args()

    # An empty query string is treated exactly like no query at all.
    chosen_query = options.geminio_query or None

    # Load the configuration with the case/attack overrides used by this script.
    cfg = breaching.get_config(overrides=["case=geminio_demo", "attack=hfgradinv"])

    # Results go to a per-query directory, or to the baseline directory.
    if chosen_query is None:
        cfg.attack.save_dir = './results/baseline/'
    else:
        query_slug = chosen_query.replace(" ", "_").replace("?", "")
        cfg.attack.save_dir = f'./results/{query_slug}/'

    # Use CUDA when available, otherwise fall back to the CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    torch.backends.cudnn.benchmark = cfg.case.impl.benchmark
    setup = {'device': device, 'dtype': getattr(torch, cfg.case.impl.dtype)}

    # Run reconstruction
    reconstruct_image(cfg, setup, chosen_query)