Skip to content

Commit df2c240

Browse files
committed
Added UpSampling3D ONNX tests
1 parent 49fe323 commit df2c240

File tree

4 files changed

+240
-1
lines changed

4 files changed

+240
-1
lines changed
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import argparse
2+
3+
import numpy as np
4+
from tensorflow.keras.models import Sequential
5+
from tensorflow.keras.layers import Input, Conv3D, UpSampling3D, MaxPooling3D
6+
import keras2onnx
7+
8+
# Training settings
9+
parser = argparse.ArgumentParser(description='Keras Conv3D+Upsampling encoder decoder with synthetic data Example')
10+
parser.add_argument('--batch-size', type=int, default=2, metavar='N',
11+
help='input batch size for training (default: 2)')
12+
parser.add_argument('--epochs', type=int, default=5, metavar='N',
13+
help='number of epochs to train (default: 5)')
14+
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
15+
help='learning rate (default: 0.01)')
16+
parser.add_argument('--no-cuda', action='store_true', default=False,
17+
help='disables CUDA training')
18+
parser.add_argument('--seed', type=int, default=1, metavar='S',
19+
help='random seed (default: 1)')
20+
parser.add_argument('--output-path', type=str, default="onnx_models/upsample3D_enc_dec_synthetic.onnx",
21+
help='Output path to store the onnx file')
22+
parser.add_argument('--output-metric', type=str, default="",
23+
help='Output file path to store the metric value obtained in test set')
24+
args = parser.parse_args()
25+
26+
# Create synthetic data
27+
n_samples = 6
28+
# Shape: (n_samples, ch=3, depth=16, height=16, width=16)
29+
x_train = np.linspace(0, 1, n_samples*3*16*16*16)
30+
x_train = x_train.reshape((n_samples, 3, 16, 16, 16)).astype(np.float32)
31+
# (B, C, D, H, W) -> (B, D, H, W, C)
32+
x_train = np.transpose(x_train, (0, 2, 3, 4, 1)) # Set channel last
33+
34+
print("Train data shape:", x_train.shape)
35+
36+
# Definer encoder
37+
model = Sequential()
38+
model.add(Input(shape=(16, 16, 16, 3)))
39+
# Encoder
40+
model.add(Conv3D(32, 3, padding="same", activation="relu"))
41+
model.add(MaxPooling3D(2, 2))
42+
model.add(Conv3D(64, 3, padding="same", activation="relu"))
43+
model.add(MaxPooling3D(2, 2))
44+
# Decoder
45+
model.add(Conv3D(64, 3, padding="same", activation="relu"))
46+
model.add(UpSampling3D((2, 2, 2)))
47+
model.add(Conv3D(32, 3, padding="same", activation="relu"))
48+
model.add(UpSampling3D((2, 2, 2)))
49+
model.add(Conv3D(3, 1, padding="valid", activation="sigmoid"))
50+
51+
model.compile(loss='mse',
52+
optimizer="adam",
53+
metrics=[])
54+
55+
model.summary()
56+
57+
# Training
58+
model.fit(x_train, x_train, batch_size=args.batch_size, epochs=args.epochs)
59+
60+
# Evaluation
61+
eval_loss = model.evaluate(x_train, x_train)
62+
print("Evaluation result: Loss:", eval_loss)
63+
64+
# In case of providing output metric file, store the test mse value
65+
if args.output_metric != "":
66+
with open(args.output_metric, 'w') as ofile:
67+
ofile.write(str(eval_loss))
68+
69+
# Convert to ONNX
70+
onnx_model = keras2onnx.convert_keras(model, "upsample3D_synthetic", debug_mode=1)
71+
# Save ONNX to file
72+
keras2onnx.save_model(onnx_model, args.output_path)

scripts/tests/py_onnx/pytorch/export_scripts/convT3D_enc_dec_synthetic_pytorch_export.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def test(model, device, test_loader):
110110
def main():
111111
# Training settings
112112
parser = argparse.ArgumentParser(
113-
description='PyTorch ConvT2D encoder-decoder MNIST Example')
113+
description='PyTorch ConvT3D encoder-decoder with synthetic data example')
114114
parser.add_argument('--batch-size', type=int, default=2, metavar='N',
115115
help='input batch size for training (default: 2)')
116116
parser.add_argument('--epochs', type=int, default=5, metavar='N',
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
from __future__ import print_function
2+
import argparse
3+
4+
import torch
5+
import torch.nn as nn
6+
import torch.nn.functional as F
7+
import torch.optim as optim
8+
import numpy as np
9+
10+
11+
class Net(nn.Module):
12+
def __init__(self):
13+
super(Net, self).__init__()
14+
# Encoder
15+
self.encoder = nn.Sequential(
16+
nn.Conv3d(3, 32, 3, stride=1, padding=1),
17+
nn.ReLU(),
18+
nn.MaxPool3d(2, 2),
19+
nn.Conv3d(32, 64, 3, stride=1, padding=1),
20+
nn.ReLU(),
21+
nn.MaxPool3d(2, 2),
22+
)
23+
# Decoder
24+
self.decoder = nn.Sequential(
25+
nn.Conv3d(64, 64, 3, stride=1, padding=1),
26+
nn.ReLU(),
27+
nn.Upsample(scale_factor=2),
28+
nn.Conv3d(64, 32, 3, stride=1, padding=1),
29+
nn.ReLU(),
30+
nn.Upsample(scale_factor=2),
31+
nn.Conv3d(32, 3, 1, stride=1, padding=0),
32+
nn.Sigmoid()
33+
)
34+
35+
def forward(self, x):
36+
x_enc = self.encoder(x)
37+
out = self.decoder(x_enc)
38+
return out
39+
40+
41+
# Prepare data loader
42+
class Dummy_datagen:
43+
def __init__(self, batch_size=2, n_samples=6):
44+
# Shape: (n_samples=n_samples, ch=3, depth=16, height=16, width=16)
45+
self.samples = np.linspace(0, 1, n_samples*3*16*16*16).reshape((n_samples, 3, 16, 16, 16)).astype(np.float32)
46+
self.curr_idx = 0 # Current index of the batch
47+
self.bs = batch_size
48+
49+
def __iter__(self):
50+
return self
51+
52+
def __len__(self):
53+
return int(self.samples.shape[0] / self.bs)
54+
55+
def __next__(self):
56+
target = self.curr_idx
57+
self.curr_idx += self.bs
58+
if target <= self.samples.shape[0]-self.bs:
59+
return self.samples[target:target+self.bs]
60+
raise StopIteration
61+
62+
def reset(self):
63+
'''Reset the iterator'''
64+
self.curr_idx = 0
65+
66+
67+
def train(args, model, device, train_loader, optimizer, epoch):
68+
model.train()
69+
loss_acc = 0
70+
current_samples = 0
71+
for batch_idx, data in enumerate(train_loader):
72+
data = torch.from_numpy(data)
73+
data = data.to(device)
74+
b, c, d, h, w = data.size()
75+
data_el_size = c * d * h * w
76+
optimizer.zero_grad()
77+
output = model(data)
78+
loss = F.mse_loss(output, data, reduction='sum')
79+
loss.backward()
80+
loss_acc += loss.item() / data_el_size
81+
current_samples += data.size(0)
82+
optimizer.step()
83+
if batch_idx % 10 == 0:
84+
print('\rTrain Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
85+
epoch, batch_idx * len(data), len(train_loader.samples),
86+
100. * batch_idx / len(train_loader), loss_acc / current_samples))
87+
88+
89+
def test(model, device, test_loader):
90+
model.eval()
91+
test_loss = 0
92+
current_samples = 0
93+
with torch.no_grad():
94+
for data in test_loader:
95+
data = torch.from_numpy(data)
96+
data = data.to(device)
97+
output = model(data)
98+
b, c, d, h, w = data.size()
99+
data_el_size = c * d * h * w
100+
test_loss += F.mse_loss(output, data, reduction='sum').item() / data_el_size
101+
current_samples += data.size(0)
102+
103+
test_loss = test_loss / current_samples
104+
print(f'\nTest set: Average loss: {test_loss:.4f}\n')
105+
106+
return test_loss
107+
108+
109+
def main():
110+
# Training settings
111+
parser = argparse.ArgumentParser(
112+
description='PyTorch Conv3D+Upsample encoder-decoder with synthetic data example')
113+
parser.add_argument('--batch-size', type=int, default=2, metavar='N',
114+
help='input batch size for training (default: 2)')
115+
parser.add_argument('--epochs', type=int, default=5, metavar='N',
116+
help='number of epochs to train (default: 5)')
117+
parser.add_argument('--lr', type=float, default=0.01, metavar='LR',
118+
help='learning rate (default: 0.01)')
119+
parser.add_argument('--no-cuda', action='store_true', default=False,
120+
help='disables CUDA training')
121+
parser.add_argument('--seed', type=int, default=1, metavar='S',
122+
help='random seed (default: 1)')
123+
parser.add_argument('--output-path', type=str, default="onnx_models/upsample3D_enc_dec_synthetic.onnx",
124+
help='Output path to store the onnx file')
125+
parser.add_argument('--output-metric', type=str, default="",
126+
help='Output file path to store the metric value obtained in test set')
127+
args = parser.parse_args()
128+
use_cuda = not args.no_cuda and torch.cuda.is_available()
129+
130+
torch.manual_seed(args.seed)
131+
132+
device = torch.device("cuda" if use_cuda else "cpu")
133+
134+
model = Net().to(device)
135+
optimizer = optim.Adam(model.parameters(), lr=args.lr)
136+
137+
# Create data generators
138+
train_loader = Dummy_datagen(args.batch_size)
139+
test_loader = Dummy_datagen(args.batch_size)
140+
141+
# Train
142+
for epoch in range(1, args.epochs + 1):
143+
train(args, model, device, train_loader, optimizer, epoch)
144+
test_loss = test(model, device, test_loader)
145+
train_loader.reset()
146+
test_loader.reset()
147+
148+
# In case of providing output metric file, store the test accuracy value
149+
if args.output_metric != "":
150+
with open(args.output_metric, 'w') as ofile:
151+
ofile.write(str(test_loss))
152+
153+
# Save to ONNX file
154+
dummy_input = torch.randn(args.batch_size, 3, 16, 16, 16, device=device)
155+
torch.onnx._export(model, dummy_input, args.output_path, keep_initializers_as_inputs=True)
156+
157+
158+
if __name__ == '__main__':
159+
main()

scripts/tests/run_onnx_tests.sh

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ scripts_to_run+=("EDDL_to_EDDL_conv3D;test_onnx_conv3D;test_onnx_conv3D")
108108
scripts_to_run+=("EDDL_to_EDDL_convT2D;test_onnx_convT2D;test_onnx_convT2D")
109109
scripts_to_run+=("EDDL_to_EDDL_convT3D;test_onnx_convT3D;test_onnx_convT3D")
110110
scripts_to_run+=("EDDL_to_EDDL_upsample2D;test_onnx_upsample2D;test_onnx_upsample2D")
111+
scripts_to_run+=("EDDL_to_EDDL_upsample3D;test_onnx_upsample3D;test_onnx_upsample3D")
111112
scripts_to_run+=("EDDL_to_EDDL_GRU_imdb;test_onnx_gru_imdb;test_onnx_gru_imdb")
112113
scripts_to_run+=("EDDL_to_EDDL_LSTM_imdb;test_onnx_lstm_imdb;test_onnx_lstm_imdb")
113114
scripts_to_run+=("EDDL_to_EDDL_GRU_mnist;test_onnx_gru_mnist;test_onnx_gru_mnist")
@@ -123,6 +124,7 @@ scripts_to_run+=("EDDL_to_EDDL_conv3D_CPU;test_onnx_conv3D_cpu;test_onnx_conv3D,
123124
#scripts_to_run+=("EDDL_to_EDDL_convT2D_CPU;test_onnx_convT2D_cpu;test_onnx_convT2D,--cpu") ConvT2D not available in CPU
124125
#scripts_to_run+=("EDDL_to_EDDL_convT3D_CPU;test_onnx_convT3D_cpu;test_onnx_convT3D,--cpu") ConvT3D not available in CPU
125126
scripts_to_run+=("EDDL_to_EDDL_upsample2D_CPU;test_onnx_upsample2D_cpu;test_onnx_upsample2D,--cpu")
127+
scripts_to_run+=("EDDL_to_EDDL_upsample3D_CPU;test_onnx_upsample3D_cpu;test_onnx_upsample3D,--cpu")
126128
scripts_to_run+=("EDDL_to_EDDL_GRU_imdb_CPU;test_onnx_gru_imdb_cpu;test_onnx_gru_imdb,--cpu")
127129
scripts_to_run+=("EDDL_to_EDDL_LSTM_imdb_CPU;test_onnx_lstm_imdb_cpu;test_onnx_lstm_imdb,--cpu")
128130
scripts_to_run+=("EDDL_to_EDDL_GRU_mnist_CPU;test_onnx_gru_mnist_cpu;test_onnx_gru_mnist,--cpu")
@@ -200,6 +202,7 @@ then
200202
eddl2onnxrt+=("EDDL_to_ONNXRT_convT2D;test_onnx_convT2D;onnxruntime_enc_dec_mnist.py")
201203
eddl2onnxrt+=("EDDL_to_ONNXRT_convT3D;test_onnx_convT3D;onnxruntime_enc_dec_synthetic3D.py")
202204
eddl2onnxrt+=("EDDL_to_ONNXRT_upsample2D;test_onnx_upsample2D;onnxruntime_enc_dec_mnist.py")
205+
eddl2onnxrt+=("EDDL_to_ONNXRT_upsample3D;test_onnx_upsample3D;onnxruntime_enc_dec_synthetic3D.py")
203206
eddl2onnxrt+=("EDDL_to_ONNXRT_GRU_imdb;test_onnx_gru_imdb;onnxruntime_imdb_keras.py,--unsqueeze-input")
204207
eddl2onnxrt+=("EDDL_to_ONNXRT_LSTM_imdb;test_onnx_lstm_imdb;onnxruntime_imdb_keras.py,--unsqueeze-input")
205208
eddl2onnxrt+=("EDDL_to_ONNXRT_LSTM_enc_dec;test_onnx_lstm_enc_dec;onnxruntime_recurrent_enc_dec_mnist.py")
@@ -212,6 +215,7 @@ then
212215
#eddl2onnxrt+=("EDDL_to_ONNXRT_convT2D_CPU;test_onnx_convT2D_cpu;onnxruntime_enc_dec_mnist.py") ConvT2D not available in CPU
213216
#eddl2onnxrt+=("EDDL_to_ONNXRT_convT3D_CPU;test_onnx_convT3D_cpu;onnxruntime_enc_dec_synthetic3D.py") ConvT3D not available in CPU
214217
eddl2onnxrt+=("EDDL_to_ONNXRT_upsample2D_CPU;test_onnx_upsample2D_cpu;onnxruntime_enc_dec_mnist.py")
218+
eddl2onnxrt+=("EDDL_to_ONNXRT_upsample3D_CPU;test_onnx_upsample3D_cpu;onnxruntime_enc_dec_synthetic3D.py")
215219
eddl2onnxrt+=("EDDL_to_ONNXRT_GRU_imdb_CPU;test_onnx_gru_imdb_cpu;onnxruntime_imdb_keras.py,--unsqueeze-input")
216220
eddl2onnxrt+=("EDDL_to_ONNXRT_LSTM_imdb_CPU;test_onnx_lstm_imdb_cpu;onnxruntime_imdb_keras.py,--unsqueeze-input")
217221
eddl2onnxrt+=("EDDL_to_ONNXRT_LSTM_enc_dec_CPU;test_onnx_lstm_enc_dec_cpu;onnxruntime_recurrent_enc_dec_mnist.py")
@@ -331,6 +335,7 @@ then
331335
pytorch2eddl+=("Pytorch_to_EDDL_convT2D;test_onnx_pytorch_convT2D;export_scripts/convT2D_enc_dec_mnist_pytorch_export.py;test_onnx_convT2D,--import")
332336
pytorch2eddl+=("Pytorch_to_EDDL_convT3D;test_onnx_pytorch_convT3D;export_scripts/convT3D_enc_dec_synthetic_pytorch_export.py;test_onnx_convT3D,--import")
333337
pytorch2eddl+=("Pytorch_to_EDDL_upsample2D;test_onnx_pytorch_upsample2D;export_scripts/upsample2D_enc_dec_mnist_pytorch_export.py;test_onnx_upsample2D,--import")
338+
pytorch2eddl+=("Pytorch_to_EDDL_upsample3D;test_onnx_pytorch_upsample3D;export_scripts/upsample3D_enc_dec_synthetic_pytorch_export.py;test_onnx_upsample3D,--import")
334339
pytorch2eddl+=("Pytorch_to_EDDL_LSTM_IMDB;test_onnx_pytorch_LSTM_imdb;export_scripts/lstm_pytorch_export.py;test_onnx_lstm_imdb,--import")
335340
pytorch2eddl+=("Pytorch_to_EDDL_GRU_IMDB;test_onnx_pytorch_GRU_imdb;export_scripts/gru_pytorch_export.py;test_onnx_gru_imdb,--import")
336341
pytorch2eddl+=("Pytorch_to_EDDL_LSTM_MNIST;test_onnx_pytorch_LSTM_mnist;export_scripts/lstm_mnist_pytorch_export.py;test_onnx_lstm_mnist,--import")
@@ -347,6 +352,7 @@ then
347352
#pytorch2eddl+=("Pytorch_to_EDDL_convT2D_CPU;test_onnx_pytorch_convT2D;none;test_onnx_convT2D,--import,--cpu") ConvT2D not available in CPU
348353
#pytorch2eddl+=("Pytorch_to_EDDL_convT3D_CPU;test_onnx_pytorch_convT3D;none;test_onnx_convT3D,--import,--cpu") ConvT3D not available in CPU
349354
pytorch2eddl+=("Pytorch_to_EDDL_upsample2D_CPU;test_onnx_pytorch_upsample2D;none;test_onnx_upsample2D,--import,--cpu")
355+
pytorch2eddl+=("Pytorch_to_EDDL_upsample3D_CPU;test_onnx_pytorch_upsample3D;none;test_onnx_upsample3D,--import,--cpu")
350356
pytorch2eddl+=("Pytorch_to_EDDL_LSTM_IMDB_CPU;test_onnx_pytorch_LSTM_imdb;none;test_onnx_lstm_imdb,--import,--cpu")
351357
pytorch2eddl+=("Pytorch_to_EDDL_GRU_IMDB_CPU;test_onnx_pytorch_GRU_imdb;none;test_onnx_gru_imdb,--import,--cpu")
352358
pytorch2eddl+=("Pytorch_to_EDDL_LSTM_MNIST_CPU;test_onnx_pytorch_LSTM_mnist;none;test_onnx_lstm_mnist,--import,--cpu")
@@ -427,6 +433,7 @@ then
427433
keras2eddl+=("Keras_to_EDDL_convT2D;test_onnx_keras_convT2D;export_scripts/convT2D_enc_dec_mnist_keras_export.py;test_onnx_convT2D,--import")
428434
keras2eddl+=("Keras_to_EDDL_convT3D;test_onnx_keras_convT3D;export_scripts/convT3D_enc_dec_synthetic_keras_export.py;test_onnx_convT3D,--import,--channels-last")
429435
keras2eddl+=("Keras_to_EDDL_upsample2D;test_onnx_keras_upsample2D;export_scripts/upsample2D_enc_dec_mnist_keras_export.py;test_onnx_upsample2D,--import")
436+
keras2eddl+=("Keras_to_EDDL_upsample3D;test_onnx_keras_upsample3D;export_scripts/upsample3D_enc_dec_synthetic_keras_export.py;test_onnx_upsample3D,--import")
430437
keras2eddl+=("Keras_to_EDDL_LSTM_IMDB;test_onnx_keras_LSTM_imdb;export_scripts/lstm_keras_export.py;test_onnx_lstm_imdb,--import")
431438
keras2eddl+=("Keras_to_EDDL_GRU_IMDB;test_onnx_keras_GRU_imdb;export_scripts/gru_keras_export.py;test_onnx_gru_imdb,--import")
432439
keras2eddl+=("Keras_to_EDDL_LSTM_MNIST;test_onnx_keras_LSTM_mnist;export_scripts/lstm_mnist_keras_export.py;test_onnx_lstm_mnist,--import")
@@ -443,6 +450,7 @@ then
443450
#keras2eddl+=("Keras_to_EDDL_convT2D_CPU;test_onnx_keras_convT2D;none;test_onnx_convT2D,--import,--cpu") ConvT2D not available in CPU
444451
#keras2eddl+=("Keras_to_EDDL_convT3D_CPU;test_onnx_keras_convT3D;none;test_onnx_convT3D,--import,--cpu,--channels-last") ConvT3D not available in CPU
445452
keras2eddl+=("Keras_to_EDDL_upsample2D_CPU;test_onnx_keras_upsample2D;none;test_onnx_upsample2D,--import,--cpu")
453+
keras2eddl+=("Keras_to_EDDL_upsample3D_CPU;test_onnx_keras_upsample3D;none;test_onnx_upsample3D,--import,--cpu")
446454
keras2eddl+=("Keras_to_EDDL_LSTM_IMDB_CPU;test_onnx_keras_LSTM_imdb;none;test_onnx_lstm_imdb,--import,--cpu")
447455
keras2eddl+=("Keras_to_EDDL_GRU_IMDB_CPU;test_onnx_keras_GRU_imdb;none;test_onnx_gru_imdb,--import,--cpu")
448456
keras2eddl+=("Keras_to_EDDL_LSTM_MNIST_CPU;test_onnx_keras_LSTM_mnist;none;test_onnx_lstm_mnist,--import,--cpu")

0 commit comments

Comments
 (0)