forked from sarahmu/Histopathology-Imaging
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path train_gan.py
More file actions
55 lines (52 loc) · 2.64 KB
/
train_gan.py
File metadata and controls
55 lines (52 loc) · 2.64 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import sys, getopt
from gan import train_gan
# Command-line option table for train_gan.py.
# Each entry is (short option letter, long option name, converter).
# The long option name doubles as the keyword argument passed to train_gan.
# A converter of None marks a boolean flag that takes no value.
# Entries are listed in the order parameters are echoed at start-up.
OPTION_SPECS = (
    ('t', 'train_data_dir', str),
    ('v', 'val_data_dir', str),
    ('o', 'output_dir', str),
    ('d', 'D_lr', float),
    ('g', 'G_lr', float),
    ('b', 'beta1', float),
    ('r', 'reg', float),
    ('n', 'num_epochs', int),
    ('l', 'loss', str),
    ('s', 'batch_size', int),
    ('e', 'eval_val', None),
    ('a', 'save_eval_img', None),
    ('c', 'device', str),
    ('i', 'num_eval_img', int),
)

# getopt spec strings derived from OPTION_SPECS so they cannot drift apart.
# SHORT_OPTS evaluates to 't:v:o:d:g:b:r:n:l:s:eac:i:'.
SHORT_OPTS = ''.join(
    short if conv is None else short + ':' for short, _, conv in OPTION_SPECS)
LONG_OPTS = [
    name if conv is None else name + '=' for _, name, conv in OPTION_SPECS]

# Usage line listing every supported option (including -i, which the old
# hand-written message omitted).
USAGE = 'train_gan.py ' + ' '.join(
    '-' + short if conv is None else '-{} <{}>'.format(short, name)
    for short, name, conv in OPTION_SPECS)


def parse_args(argv):
    """Parse train_gan.py command-line arguments into train_gan keyword args.

    Args:
        argv: Argument list without the program name (e.g. sys.argv[1:]).

    Returns:
        A dict mapping train_gan keyword names to converted values.
        - Boolean flags (eval_val, save_eval_img) are always present: True
          when given, False otherwise.
        - Valued options appear only when supplied on the command line.

    Raises:
        getopt.GetoptError: on an unknown option or a missing option value.
            Also raised for a value that cannot be converted to the required
            int/float, and for stray positional arguments.
    """
    opts, extra = getopt.getopt(argv, SHORT_OPTS, LONG_OPTS)
    # getopt stops at the first non-option word, so anything after it would
    # otherwise be silently ignored; reject it instead.
    if extra:
        raise getopt.GetoptError(
            'unexpected positional arguments: ' + ' '.join(extra))

    # Map both spellings ("-t" and "--train_data_dir") to the same spec.
    lookup = {}
    for short, name, conv in OPTION_SPECS:
        lookup['-' + short] = (name, conv)
        lookup['--' + name] = (name, conv)

    # Flags default to False: their only purpose is to switch something on.
    params = {name: False for _, name, conv in OPTION_SPECS if conv is None}
    for opt, arg in opts:
        name, conv = lookup[opt]
        if conv is None:
            params[name] = True
            continue
        try:
            params[name] = conv(arg)
        except ValueError:
            raise getopt.GetoptError(
                'invalid value for {}: {!r}'.format(opt, arg), opt)
    return params


def describe_params(params):
    """Return the start-up message echoing the parameters in `params`.

    Parameters are listed in OPTION_SPECS order as "--name=value", joined
    by ", ", after the prefix "Running train_gan.py with parameters: ".
    Names absent from `params` are skipped.
    """
    parts = ['--{}={}'.format(name, params[name])
             for _, name, _ in OPTION_SPECS if name in params]
    return 'Running train_gan.py with parameters: ' + ', '.join(parts)


if __name__ == '__main__':
    try:
        params = parse_args(sys.argv[1:])
    except getopt.GetoptError as err:
        print(err, file=sys.stderr)
        print(USAGE, file=sys.stderr)
        sys.exit(2)
    print(describe_params(params))
    # Options not given on the command line are omitted, so train_gan's own
    # defaults apply. NOTE(review): assumes train_gan defines defaults for
    # any option a user may leave out; its signature is not visible here.
    train_gan(**params)
    print('Finished training GAN!')