
Commit a71b3f5

Pull fixes from develop in examples.

1 parent e818443

8 files changed: +282 -14 lines

examples/histogen/extract_code_default_model.txt

Lines changed: 3 additions & 0 deletions

@@ -3,5 +3,8 @@ size = 256
 batch_size = 128
 use_gpus = True
 ckpt_directory = './'
+ckpt_restart = 'checkpoint/vqvae_001.pt'
+lmdb_filename = 'lmdb_001'
+data_dir = '../../Data/Examples/histogen/svs_pngs/'
 
 
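These default-model files are INI-style keyword files whose values are written as Python literals (quoted strings, booleans, numbers). As a rough illustration only, not the actual CANDLE parser (candle.finalize_parameters does the real work), such a file could be read with the standard library's configparser plus ast.literal_eval:

import ast
import configparser

# Hypothetical standalone reader for a default-model file.
config = configparser.ConfigParser()
config.read('extract_code_default_model.txt')

params = {}
for section in config.sections():  # e.g. [Global_Params]
    for key, raw in config.items(section):
        try:
            params[key] = ast.literal_eval(raw)  # "'lmdb_001'" -> 'lmdb_001', '128' -> 128
        except (ValueError, SyntaxError):
            params[key] = raw  # leave unparseable values as plain strings

print(params.get('ckpt_restart'))  # checkpoint/vqvae_001.pt
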
examples/histogen/sample_default_model.txt

Lines changed: 1 addition & 2 deletions

@@ -8,5 +8,4 @@ batch_size = 8
 use_gpus = True
 
 [Monitor_Params]
-timeout=3600
-
+timeout = 3600
examples/histogen/train_pixelsnail_baseline_pytorch.py

Lines changed: 252 additions & 0 deletions
@@ -0,0 +1,252 @@
+import sys
+import os
+
+import numpy as np
+import torch
+from torch import nn, optim
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+from argparse import SUPPRESS
+
+try:
+    from apex import amp
+
+except ImportError:
+    amp = None
+
+from dataset import LMDBDataset
+from pixelsnail import PixelSNAIL
+from scheduler import CycleScheduler
+
+file_path = os.path.dirname(os.path.realpath(__file__))
+lib_path = os.path.abspath(os.path.join(file_path, '..'))
+sys.path.append(lib_path)
+lib_path2 = os.path.abspath(os.path.join(file_path, '..', '..', 'common'))
+sys.path.append(lib_path2)
+
+
+import candle
+
+additional_definitions = [
+    {'name': 'sched_mode',
+     'type': str,
+     'default': None,
+     'help': 'Mode of learning rate scheduler'},
+    {'name': 'lmdb_filename',
+     'type': str,
+     'default': SUPPRESS,
+     'help': 'lmdb dataset path'},
+    {'name': 'amp',
+     'type': str,
+     'default': 'O0',
+     'help': ''},
+    {'name': 'hier',
+     'type': str,
+     'default': 'top',
+     'help': ''},
+    {'name': 'channel',
+     'type': int,
+     'default': 256,
+     'help': ''},
+    {'name': 'n_res_block',
+     'type': int,
+     'default': 4,
+     'help': ''},
+    {'name': 'n_res_channel',
+     'type': int,
+     'default': 256,
+     'help': ''},
+    {'name': 'n_out_res_block',
+     'type': int,
+     'default': 0,
+     'help': ''},
+    {'name': 'n_cond_res_block',
+     'type': int,
+     'default': 3,
+     'help': ''},
+    {'name': 'ckpt_restart',
+     'type': str,
+     'default': None,
+     'help': 'Checkpoint to restart from'},
+]
+
+required = [
+    'batch_size',
+    'epochs',
+    'hier',
+    'learning_rate',
+    'channel',
+    'n_res_block',
+    'n_res_channel',
+    'n_out_res_block',
+    'n_cond_res_block',
+    'dropout',
+    'amp',
+    'sched_mode',
+    'lmdb_filename',
+]
+
+
+class TrPxSnBk(candle.Benchmark):
+
+    def set_locals(self):
+        """Functionality to set variables specific for the benchmark
+        - required: set of required parameters for the benchmark.
+        - additional_definitions: list of dictionaries describing the additional parameters for the
+        benchmark.
+        """
+
+        if required is not None:
+            self.required = set(required)
+        if additional_definitions is not None:
+            self.additional_definitions = additional_definitions
+
+
+def initialize_parameters(default_model='train_pixelsnail_default_model.txt'):
+
+    # Build benchmark object
+    trpsn = TrPxSnBk(file_path, default_model, 'pytorch',
+                     prog='train_pixelsnail_baseline',
+                     desc='Histology train pixelsnail - Examples')
+
+    print("Created sample benchmark")
+
+    # Initialize parameters
+    gParameters = candle.finalize_parameters(trpsn)
+    print("Parameters initialized")
+
+    return gParameters
+
+
+def train(args, epoch, loader, model, optimizer, scheduler, device):
+    loader = tqdm(loader)
+
+    criterion = nn.CrossEntropyLoss()
+
+    for i, (top, bottom, label) in enumerate(loader):
+        model.zero_grad()
+
+        top = top.to(device)
+
+        if args.hier == 'top':
+            target = top
+            out, _ = model(top)
+
+        elif args.hier == 'bottom':
+            bottom = bottom.to(device)
+            target = bottom
+            out, _ = model(bottom, condition=top)
+
+        loss = criterion(out, target)
+        loss.backward()
+
+        if scheduler is not None:
+            scheduler.step()
+        optimizer.step()
+
+        _, pred = out.max(1)
+        correct = (pred == target).float()
+        accuracy = correct.sum() / target.numel()
+
+        lr = optimizer.param_groups[0]['lr']
+
+        loader.set_description(
+            (
+                f'epoch: {epoch + 1}; loss: {loss.item():.5f}; '
+                f'acc: {accuracy:.5f}; lr: {lr:.5f}'
+            )
+        )
+
+
+class PixelTransform:
+    def __init__(self):
+        pass
+
+    def __call__(self, input):
+        ar = np.array(input)
+
+        return torch.from_numpy(ar).long()
+
+
+def run(params):
+
+    args = candle.ArgumentStruct(**params)
+    # Configure GPUs
+    ndevices = torch.cuda.device_count()
+    if ndevices < 1:
+        raise Exception('No CUDA gpus available')
+
+    device = 'cuda'
+
+    dataset = LMDBDataset(args.lmdb_filename)
+    loader = DataLoader(
+        dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True
+    )
+
+    ckpt = {}
+
+    if args.ckpt_restart is not None:
+        ckpt = torch.load(args.ckpt_restart)
+        args = ckpt['args']
+
+    if args.hier == 'top':
+        model = PixelSNAIL(
+            [32, 32],
+            512,
+            args.channel,
+            5,
+            4,
+            args.n_res_block,
+            args.n_res_channel,
+            dropout=args.dropout,
+            n_out_res_block=args.n_out_res_block,
+        )
+
+    elif args.hier == 'bottom':
+        model = PixelSNAIL(
+            [64, 64],
+            512,
+            args.channel,
+            5,
+            4,
+            args.n_res_block,
+            args.n_res_channel,
+            attention=False,
+            dropout=args.dropout,
+            n_cond_res_block=args.n_cond_res_block,
+            cond_res_channel=args.n_res_channel,
+        )
+
+    if 'model' in ckpt:
+        model.load_state_dict(ckpt['model'])
+
+    model = model.to(device)
+    optimizer = optim.Adam(model.parameters(), lr=args.learning_rate)
+
+    if amp is not None:
+        model, optimizer = amp.initialize(model, optimizer, opt_level=args.amp)
+
+    model = nn.DataParallel(model)
+    model = model.to(device)
+
+    scheduler = None
+    if args.sched_mode == 'cycle':
+        scheduler = CycleScheduler(
+            optimizer, args.learning_rate, n_iter=len(loader) * args.epochs, momentum=None
+        )
+
+    for i in range(args.epochs):
+        train(args, i, loader, model, optimizer, scheduler, device)
+        torch.save(
+            {'model': model.module.state_dict(), 'args': args},
+            f'{args.ckpt_directory}/checkpoint/pixelsnail_{args.hier}_{str(i + 1).zfill(3)}.pt',
+        )
+
+
+def main():
+    params = initialize_parameters()
+    run(params)
+
+
+if __name__ == '__main__':
+    main()

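The new script saves one checkpoint per epoch ({'model': ..., 'args': ...}) and, when ckpt_restart is set, reloads both the weights and the saved argument namespace before training resumes. A minimal sketch of that save/restore round trip, with a toy nn.Linear standing in for PixelSNAIL and a plain dict standing in for the ArgumentStruct:

import torch
from torch import nn

model = nn.Linear(4, 4)                 # toy stand-in for PixelSNAIL
args = {'hier': 'top', 'channel': 256}  # stand-in for the argument namespace

# Save in the same layout as run(): weights under 'model', namespace under 'args'.
torch.save({'model': model.state_dict(), 'args': args}, 'pixelsnail_top_001.pt')

# Restart path, mirroring the ckpt_restart branch of run().
ckpt = torch.load('pixelsnail_top_001.pt')
args = ckpt['args']
model.load_state_dict(ckpt['model'])
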
examples/histogen/train_pixelsnail_default_model.txt

Lines changed: 1 addition & 2 deletions

@@ -1,4 +1,5 @@
 [Global_Params]
+lmdb_filename = 'lmdb_001'
 epochs = 420
 batch_size = 32
 learning_rate = 3e-4
@@ -12,5 +13,3 @@ dropout = 0.1
 amp = 'O0'
 use_gpus = True
 ckpt_directory = './'
-
-

examples/histogen/train_vqvae_baseline_pytorch.py

Lines changed: 5 additions & 5 deletions

@@ -43,7 +43,7 @@
      'type': str,
      'default': SUPPRESS,
      'help': 'dataset path'},
-    {'name': 'size',
+    {'name': 'image_size',
      'type': int,
      'default': 256,
      'help': 'Image size to use'},
@@ -55,7 +55,7 @@
     'epochs',
     'learning_rate',
     'sched_mode',
-    'size',
+    'image_size',
 ]
 
 
@@ -168,8 +168,8 @@ def config_and_train(args):
 
     transform = transforms.Compose(
        [
-            transforms.Resize(args.size),
-            transforms.CenterCrop(args.size),
+            transforms.Resize(args.image_size),
+            transforms.CenterCrop(args.image_size),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
@@ -212,7 +212,7 @@ def fetch_data(params):
     if params['data_dir'] is None:
         params['data_dir'] = candle.fetch_file(data_url + params['train_data'], subdir='Examples/histogen')
     else:
-        tempfile = candle.fetch_file(data_url + params['train_data'], cache_subdir='Examples/histogen')
+        tempfile = candle.fetch_file(data_url + params['train_data'], subdir='Examples/histogen')
         params['data_dir'] = os.path.join(os.path.dirname(tempfile), params['data_dir'])
 
 
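For reference, the renamed image_size parameter drives the standard torchvision resize/crop/normalize pipeline shown in the hunk above. A self-contained sketch of the same preprocessing applied to a single image (the PNG path is hypothetical):

from PIL import Image
from torchvision import transforms

image_size = 256  # default from train_vqvae_default_model.txt

# Same pipeline as config_and_train(): resize, center-crop, then map to [-1, 1].
transform = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

img = Image.open('svs_pngs/example.png').convert('RGB')  # hypothetical file
x = transform(img)  # float tensor of shape [3, 256, 256], values in [-1, 1]
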
examples/histogen/train_vqvae_default_model.txt

Lines changed: 2 additions & 3 deletions

@@ -5,7 +5,6 @@ data_dir = 'svs_pngs'
 epochs = 560
 learning_rate = 3e-4
 batch_size = 128
-use_gpus = True
+n_gpu_per_machine = 1
 ckpt_directory = './'
-
-
+image_size = 256

examples/rnngen/infer_rnngen_baseline_pytorch.py

Lines changed: 17 additions & 2 deletions

@@ -10,7 +10,7 @@
 
 from model.model import CharRNN
 from model.vocab import START_CHAR, END_CHAR
-from train import get_vocab_from_file
+from model.vocab import get_vocab_from_file
 
 file_path = os.path.dirname(os.path.realpath(__file__))
 lib_path = os.path.abspath(os.path.join(file_path, '..'))
@@ -70,6 +70,7 @@
     'output',
     'input',
     'nsamples',
+    'model',
 ]
 
 
@@ -175,6 +176,18 @@ def run(params):
     print("Note: This script is very picky. Please check device output to see where this is running. ")
     args = candle.ArgumentStruct(**params)
 
+    data_url = args.data_url
+
+    if args.model == 'ft_goodperforming_model.pt':
+        file = 'pilot1/ft_goodperforming_model.pt'
+    elif args.model == 'ft_poorperforming_model.pt':
+        file = 'pilot1/ft_poorperforming_model.pt'
+    else: # Corresponding to args.model == 'autosave.model.pt':
+        file = 'mosesrun/autosave.model.pt'
+
+    print('Recovering trained model')
+    trained = candle.fetch_file(data_url + file, subdir='examples/rnngen')
+
     # Configure GPU
     if args.use_gpus and torch.cuda.is_available():
         device = 'cuda'
@@ -188,10 +201,12 @@ def run(params):
     model = CharRNN(len(vocab), len(vocab), max_len=args.maxlen).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    pt = torch.load(args.logdir + "/" + args.model, map_location=device)
+    print("Loading trained model.")
+    pt = torch.load(trained, map_location=device)
     model.load_state_dict(pt['state_dict'])
     optimizer.load_state_dict(pt['optim_state_dict'])
 
+    print("Applying to loaded data")
     total_sampled = 0
     total_valid = 0
     total_unqiue = 0

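The new block resolves args.model to a remote path under data_url and downloads it with candle.fetch_file before loading, rather than reading from logdir. The same selection logic, sketched as a dict lookup (names taken from the hunk above; the fallback mirrors the else branch):

# Remote paths keyed by the supported 'model' parameter values.
MODEL_FILES = {
    'ft_goodperforming_model.pt': 'pilot1/ft_goodperforming_model.pt',
    'ft_poorperforming_model.pt': 'pilot1/ft_poorperforming_model.pt',
    'autosave.model.pt': 'mosesrun/autosave.model.pt',
}

def resolve_model_file(model_name):
    # Unrecognized names fall back to the moses autosave model,
    # matching the else branch in run().
    return MODEL_FILES.get(model_name, 'mosesrun/autosave.model.pt')

# In run(), the resolved path is then fetched and loaded:
#   trained = candle.fetch_file(data_url + resolve_model_file(args.model),
#                               subdir='examples/rnngen')
#   pt = torch.load(trained, map_location=device)
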
examples/rnngen/infer_rnngen_default_model.txt

Lines changed: 1 addition & 0 deletions

@@ -1,4 +1,5 @@
 [Global_Params]
+data_url = 'ftp://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Examples/rnngen/'
 input = 'mosesrun/'
 logdir = 'mosesrun/'
 output = 'samples.txt'
