
Commit 731566b

Merge branch 'release_04' of https://github.com/ECP-CANDLE/Benchmarks into release_04
2 parents: 6278290 + ee33263

9 files changed: 34 additions & 34 deletions

Pilot3/P3B5/test.py

Lines changed: 0 additions & 16 deletions
This file was deleted.

examples/histogen/extract_code_default_model.txt

Lines changed: 3 additions & 0 deletions
@@ -3,5 +3,8 @@ size = 256
 batch_size = 128
 use_gpus = True
 ckpt_directory = './'
+ckpt_restart = 'checkpoint/vqvae_001.pt'
+lmdb_filename = 'lmdb_001'
+data_dir = '../../Data/Examples/histogen/svs_pngs/'
 
 

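The added keys follow the same INI-style layout as the other *_default_model.txt files in this commit, so they can be read with the standard library. A minimal sketch, assuming the keys sit under a [Global_Params] section (the section header is not visible in this hunk) and keep their quoted string values:

# Hedged sketch: read the new extract_code defaults with configparser.
# The [Global_Params] section name is assumed from the other default_model
# files in this commit; values are stored as quoted strings in the file.
from configparser import ConfigParser

cfg = ConfigParser()
cfg.read('extract_code_default_model.txt')

params = cfg['Global_Params']
ckpt_restart = params['ckpt_restart'].strip("'")    # checkpoint/vqvae_001.pt
lmdb_filename = params['lmdb_filename'].strip("'")  # lmdb_001
data_dir = params['data_dir'].strip("'")            # ../../Data/Examples/histogen/svs_pngs/
print(ckpt_restart, lmdb_filename, data_dir)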
examples/histogen/sample_default_model.txt

Lines changed: 1 addition & 2 deletions
@@ -8,5 +8,4 @@ batch_size = 8
 use_gpus = True
 
 [Monitor_Params]
-timeout=3600
-
+timeout = 3600

examples/histogen/train_pixelsnail_baseline_pytroch.py renamed to examples/histogen/train_pixelsnail_baseline_pytorch.py

Lines changed: 4 additions & 4 deletions
@@ -32,10 +32,10 @@
      'type': str,
      'default': None,
      'help': 'Mode of learning rate scheduler'},
-    {'name': 'data_dir',
+    {'name': 'lmdb_filename',
      'type': str,
      'default': SUPPRESS,
-     'help': 'dataset path'},
+     'help': 'lmdb dataset path'},
     {'name': 'amp',
      'type': str,
      'default': 'O0',
@@ -83,7 +83,7 @@
     'dropout',
     'amp',
     'sched_mode',
-    'data_dir',
+    'lmdb_filename',
 ]
 
 
@@ -178,7 +178,7 @@ def run(params):
 
     device = 'cuda'
 
-    dataset = LMDBDataset(args.data_dir)
+    dataset = LMDBDataset(args.lmdb_filename)
     loader = DataLoader(
         dataset, batch_size=args.batch_size, shuffle=True, num_workers=4, drop_last=True
     )

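For context on the rename: LMDBDataset is constructed directly from the LMDB path, so lmdb_filename describes the argument more accurately than data_dir. A minimal sketch of an LMDB-backed dataset feeding the same DataLoader call, assuming an index-keyed database with a b'length' entry (the class name and key layout are illustrative, not the repository's implementation):

# Hypothetical LMDB-backed dataset; key layout and class name are assumptions.
import pickle

import lmdb
from torch.utils.data import Dataset, DataLoader

class SimpleLMDBDataset(Dataset):
    def __init__(self, lmdb_filename):
        # readonly + lock=False lets several DataLoader workers share the env
        self.env = lmdb.open(lmdb_filename, readonly=True, lock=False)
        with self.env.begin() as txn:
            self.length = int(txn.get(b'length').decode())

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        with self.env.begin() as txn:
            return pickle.loads(txn.get(str(index).encode()))

# Same DataLoader call as in the diff, now fed by the renamed argument:
# dataset = SimpleLMDBDataset(args.lmdb_filename)
# loader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True,
#                     num_workers=4, drop_last=True)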
examples/histogen/train_pixelsnail_default_model.txt

Lines changed: 1 addition & 2 deletions
@@ -1,4 +1,5 @@
 [Global_Params]
+lmdb_filename = 'lmdb_001'
 epochs = 420
 batch_size = 32
 learning_rate = 3e-4
@@ -12,5 +13,3 @@ dropout = 0.1
 amp = 'O0'
 use_gpus = True
 ckpt_directory = './'
-
-

examples/histogen/train_vqvae_baseline_pytorch.py

Lines changed: 5 additions & 5 deletions
@@ -43,7 +43,7 @@
      'type': str,
      'default': SUPPRESS,
      'help': 'dataset path'},
-    {'name': 'size',
+    {'name': 'image_size',
      'type': int,
      'default': 256,
      'help': 'Image size to use'},
@@ -55,7 +55,7 @@
     'epochs',
     'learning_rate',
     'sched_mode',
-    'size',
+    'image_size',
 ]
 
 
@@ -168,8 +168,8 @@ def config_and_train(args):
 
     transform = transforms.Compose(
         [
-            transforms.Resize(args.size),
-            transforms.CenterCrop(args.size),
+            transforms.Resize(args.image_size),
+            transforms.CenterCrop(args.image_size),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
         ]
@@ -212,7 +212,7 @@ def fetch_data(params):
     if params['data_dir'] is None:
         params['data_dir'] = candle.fetch_file(data_url + params['train_data'], subdir='Examples/histogen')
     else:
-        tempfile = candle.fetch_file(data_url + params['train_data'], cache_subdir='Examples/histogen')
+        tempfile = candle.fetch_file(data_url + params['train_data'], subdir='Examples/histogen')
         params['data_dir'] = os.path.join(os.path.dirname(tempfile), params['data_dir'])
 
 

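As a quick check of the renamed image_size parameter, the preprocessing pipeline above can be exercised in isolation; a minimal, self-contained sketch with a stand-in PIL image (the dummy image and hard-coded value are illustrative only, not benchmark code):

# Exercise the image_size-driven torchvision preprocessing on a dummy tile.
from PIL import Image
from torchvision import transforms

image_size = 256  # default value from train_vqvae_default_model.txt

preprocess = transforms.Compose([
    transforms.Resize(image_size),
    transforms.CenterCrop(image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

img = Image.new('RGB', (512, 384))   # stand-in for a PNG tile
x = preprocess(img)
print(x.shape)                        # torch.Size([3, 256, 256])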
examples/histogen/train_vqvae_default_model.txt

Lines changed: 2 additions & 3 deletions
@@ -5,7 +5,6 @@ data_dir = 'svs_pngs'
 epochs = 560
 learning_rate = 3e-4
 batch_size = 128
-use_gpus = True
+n_gpu_per_machine = 1
 ckpt_directory = './'
-
-
+image_size = 256

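The boolean use_gpus flag is replaced here by an explicit GPU count. How the training script consumes n_gpu_per_machine is not shown in this commit; a hedged illustration of mapping such a count onto the devices actually present might look like:

# Illustrative only: clamp a configured n_gpu_per_machine to available GPUs.
import torch

n_gpu_per_machine = 1  # value from train_vqvae_default_model.txt
n_use = min(n_gpu_per_machine, torch.cuda.device_count())

devices = [torch.device(f'cuda:{i}') for i in range(n_use)] or [torch.device('cpu')]
print(f'using {len(devices)} device(s): {devices}')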
examples/rnngen/infer_rnngen_baseline_pytorch.py

Lines changed: 17 additions & 2 deletions
@@ -10,7 +10,7 @@
 
 from model.model import CharRNN
 from model.vocab import START_CHAR, END_CHAR
-from train import get_vocab_from_file
+from model.vocab import get_vocab_from_file
 
 file_path = os.path.dirname(os.path.realpath(__file__))
 lib_path = os.path.abspath(os.path.join(file_path, '..'))
@@ -70,6 +70,7 @@
     'output',
     'input',
     'nsamples',
+    'model',
 ]
 
 
@@ -175,6 +176,18 @@ def run(params):
     print("Note: This script is very picky. Please check device output to see where this is running. ")
     args = candle.ArgumentStruct(**params)
 
+    data_url = args.data_url
+
+    if args.model == 'ft_goodperforming_model.pt':
+        file = 'pilot1/ft_goodperforming_model.pt'
+    elif args.model == 'ft_poorperforming_model.pt':
+        file = 'pilot1/ft_poorperforming_model.pt'
+    else:  # Corresponding to args.model == 'autosave.model.pt'
+        file = 'mosesrun/autosave.model.pt'
+
+    print('Recovering trained model')
+    trained = candle.fetch_file(data_url + file, subdir='examples/rnngen')
+
     # Configure GPU
     if args.use_gpus and torch.cuda.is_available():
         device = 'cuda'
@@ -188,10 +201,12 @@
     model = CharRNN(len(vocab), len(vocab), max_len=args.maxlen).to(device)
     optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)
 
-    pt = torch.load(args.logdir + "/" + args.model, map_location=device)
+    print("Loading trained model.")
+    pt = torch.load(trained, map_location=device)
     model.load_state_dict(pt['state_dict'])
     optimizer.load_state_dict(pt['optim_state_dict'])
 
+    print("Applying to loaded data")
     total_sampled = 0
     total_valid = 0
     total_unqiue = 0

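With these changes the inference script resolves the requested checkpoint name to a remote path under data_url, fetches it through the CANDLE file cache, and restores model and optimizer state from the downloaded file. A condensed sketch of that flow; the helper function and its device default are illustrative, while the file names and the fetch call mirror the diff above:

# Condensed sketch of the new checkpoint recovery; fetch_trained_model is a
# hypothetical helper, the remote file names come from the diff above.
import torch
import candle

def fetch_trained_model(args, device='cpu'):
    if args.model == 'ft_goodperforming_model.pt':
        remote = 'pilot1/ft_goodperforming_model.pt'
    elif args.model == 'ft_poorperforming_model.pt':
        remote = 'pilot1/ft_poorperforming_model.pt'
    else:  # autosave.model.pt
        remote = 'mosesrun/autosave.model.pt'

    # downloads once, then reuses the cached copy on later runs
    local_path = candle.fetch_file(args.data_url + remote, subdir='examples/rnngen')
    return torch.load(local_path, map_location=device)

# pt = fetch_trained_model(args, device)
# model.load_state_dict(pt['state_dict'])
# optimizer.load_state_dict(pt['optim_state_dict'])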
examples/rnngen/infer_rnngen_default_model.txt

Lines changed: 1 addition & 0 deletions
@@ -1,4 +1,5 @@
 [Global_Params]
+data_url = 'ftp://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Examples/rnngen/'
 input = 'mosesrun/'
 logdir = 'mosesrun/'
 output = 'samples.txt'
