Skip to content

Commit 8d2446f

Browse files
committed
Merge branch 'release_04' of github.com:ECP-CANDLE/Benchmarks into release_04
2 parents 00bd4f9 + 9815442 commit 8d2446f

File tree

4 files changed

+134
-2
lines changed

4 files changed

+134
-2
lines changed

examples/image-vae/image_vae_baseline_pytorch.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def run(gParams):
9797
test_file = candle.fetch_file(data_url + test_data, subdir='Examples/image_vae')
9898

9999
starting_epoch = 1
100-
total_epochs = None
100+
total_epochs = gParams['epochs']
101101

102102
rng_seed = 42
103103
torch.manual_seed(rng_seed)
@@ -263,7 +263,7 @@ def test(epoch, args):
263263
if total_epochs is None:
264264
trn_rng = itertools.count(start=starting_epoch)
265265
else:
266-
trn_rng = range(starting_epoch, total_epochs)
266+
trn_rng = range(starting_epoch, total_epochs + 1)
267267

268268
for epoch in trn_rng:
269269
for param_group in optimizer.param_groups:

examples/image-vae/image_vae_default_model.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ data_url = 'ftp://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Examples/image_va
33
train_data = 'train.csv'
44
test_data = 'test.csv'
55
workers = 16
6+
epochs = None
67
batch_size = 256
78
grad_clip = 2.0
89
model_path = 'models'
Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
import logging
2+
import os
3+
import sys
4+
5+
import numpy as np
6+
import torch
7+
from sklearn.linear_model import LinearRegression
8+
from torchvision.utils import save_image
9+
10+
from model import GeneralVae, PictureDecoder, PictureEncoder
11+
12+
logger = logging.getLogger('cairosvg')
13+
logger.setLevel(logging.CRITICAL)
14+
15+
file_path = os.path.dirname(os.path.realpath(__file__))
16+
lib_path = os.path.abspath(os.path.join(file_path, '..', '..', 'common'))
17+
sys.path.append(lib_path)
18+
19+
import candle
20+
21+
additional_definitions = [
22+
{'name': 'batch_size', 'default': 64, 'type': int,
23+
'help': 'mini-batch size per process (default: 256)'},
24+
{'name': 'output_dir', 'help': 'output files path',
25+
'default': 'samples/'},
26+
{'name': 'checkpoint', 'type': str,
27+
'help': 'saved model to sample from'},
28+
{'name': 'num_samples', 'type': int, 'default': 64, 'help': 'number of samples to draw'},
29+
{'name': 'image', 'type': candle.str2bool, 'help': 'save images instead of numpy array'}
30+
]
31+
32+
required = ['checkpoint']
33+
34+
35+
class BenchmarkSample(candle.Benchmark):
36+
37+
def set_locals(self):
38+
"""Functionality to set variables specific for the benchmark
39+
- required: set of required parameters for the benchmark.
40+
- additional_definitions: list of dictionaries describing the additional parameters for the
41+
benchmark.
42+
"""
43+
44+
if required is not None:
45+
self.required = set(required)
46+
if additional_definitions is not None:
47+
self.additional_definitions = additional_definitions
48+
49+
50+
def initialize_parameters(default_model='sample_default_model.txt'):
51+
52+
# Build benchmark object
53+
sampleBmk = BenchmarkSample(file_path, default_model, 'pytorch',
54+
prog='sample_baseline',
55+
desc='PyTorch ImageNet')
56+
57+
# Initialize parameters
58+
gParameters = candle.finalize_parameters(sampleBmk)
59+
# logger.info('Params: {}'.format(gParameters))
60+
61+
return gParameters
62+
63+
64+
if __name__ == '__main__':
65+
gParams = initialize_parameters()
66+
args = candle.ArgumentStruct(**gParams)
67+
68+
# args = get_args()
69+
70+
starting_epoch = 1
71+
total_epochs = None
72+
73+
# seed = 42
74+
# torch.manual_seed(seed)
75+
76+
log_interval = 25
77+
LR = 5.0e-4
78+
79+
cuda = True
80+
device = torch.device("cuda" if cuda and torch.cuda.is_available() else "cpu")
81+
82+
encoder = PictureEncoder(rep_size=512)
83+
decoder = PictureDecoder(rep_size=512)
84+
85+
checkpoint = None
86+
if args.checkpoint is not None:
87+
checkpoint = torch.load(args.model_path + '/' + args.checkpoint, map_location='cpu')
88+
print(f"Loading Checkpoint ({args.checkpoint}).")
89+
starting_epoch = checkpoint['epoch'] + 1
90+
encoder.load_state_dict(checkpoint['encoder_state_dict'])
91+
decoder.load_state_dict(checkpoint['decoder_state_dict'])
92+
93+
encoder = encoder.to(device)
94+
decoder = decoder.to(device)
95+
model = GeneralVae(encoder, decoder, rep_size=512).to(device)
96+
97+
def interpolate_points(x, y, sampling):
98+
ln = LinearRegression()
99+
data = np.stack((x, y))
100+
data_train = np.array([0, 1]).reshape(-1, 1)
101+
ln.fit(data_train, data)
102+
103+
return ln.predict(sampling.reshape(-1, 1)).astype(np.float32)
104+
105+
times = int(args.num_samples / args.batch_size)
106+
print(
107+
f"Using batch size {args.batch_size} and sampling {times} times for a total of {args.batch_size * times} samples drawn. Saving {'images' if args.image else 'numpy array'}")
108+
samples = []
109+
for i in range(times):
110+
with torch.no_grad():
111+
sample = torch.randn(args.batch_size, 512).to(device)
112+
sample = model.decode(sample).cpu()
113+
114+
if args.image:
115+
save_image(sample.view(args.batch_size, 3, 256, 256),
116+
args.output_dir + '/sample_' + str(i) + '.png')
117+
else:
118+
samples.append(sample.view(args.batch_size, 3, 256, 256).cpu().numpy())
119+
120+
if not args.image:
121+
samples = np.concatenate(samples, axis=0)
122+
np.save(f"{args.output_dir}/samples.npy", samples)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
[Global_Params]
2+
data_url = 'http://ftp.mcs.anl.gov/pub/candle/public/benchmarks/Examples/image_vae'
3+
test_model = 'model.pt'
4+
num_samples = 64
5+
batch_size = 64
6+
model_path = 'models'
7+
checkpoint = 'epoch_6.pt'
8+
output_dir = 'samples'
9+
image = True

0 commit comments

Comments
 (0)