Skip to content

Commit 1a24862

Browse files
committed
Added upsampling layer support.
1 parent 83da7de commit 1a24862

File tree

2 files changed

+80
-0
lines changed

2 files changed

+80
-0
lines changed

pytorch2keras/layers.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -686,6 +686,32 @@ def convert_constant(params, w_name, scope_name, inputs, layers, weights):
686686
layers[scope_name] = lambda_layer(layers[inputs[0]])
687687

688688

689+
def convert_upsample(params, w_name, scope_name, inputs, layers, weights):
690+
"""
691+
Convert upsample_bilinear2d layer.
692+
693+
Args:
694+
params: dictionary with layer parameters
695+
w_name: name prefix in state_dict
696+
scope_name: pytorch scope name
697+
inputs: pytorch node inputs
698+
layers: dictionary with keras tensors
699+
weights: pytorch state_dict
700+
"""
701+
print('Converting upsample...')
702+
703+
if params['mode'] != 'nearest':
704+
raise AssertionError('Cannot convert non-nearest upsampling')
705+
706+
tf_name = w_name + str(random.random())
707+
708+
scale = (params['height_scale'], params['width_scale'])
709+
upsampling = keras.layers.UpSampling2D(
710+
size=scale, name=tf_name
711+
)
712+
layers[scope_name] = upsampling(layers[inputs[0]])
713+
714+
689715
AVAILABLE_CONVERTERS = {
690716
'Conv': convert_conv,
691717
'ConvTranspose': convert_convtranspose,
@@ -710,4 +736,5 @@ def convert_constant(params, w_name, scope_name, inputs, layers, weights):
710736
'Gather': convert_gather,
711737
'ReduceSum': convert_reduce_sum,
712738
'Constant': convert_constant,
739+
'Upsample': convert_upsample,
713740
}

tests/upsample_nearest.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
import keras # work around segfault
2+
import sys
3+
import numpy as np
4+
5+
import torch
6+
import torch.nn as nn
7+
import torch.nn.functional as F
8+
from torch.autograd import Variable
9+
10+
sys.path.append('../pytorch2keras')
11+
from converter import pytorch_to_keras
12+
13+
14+
class TestUpsampleNearest2d(nn.Module):
15+
"""Module for UpsampleNearest2d conversion testing
16+
"""
17+
18+
def __init__(self, inp=10, out=16, kernel_size=3, bias=True):
19+
super(TestUpsampleNearest2d, self).__init__()
20+
self.conv2d = nn.Conv2d(inp, out, kernel_size=kernel_size, bias=bias)
21+
self.up = nn.UpsamplingNearest2d(scale_factor=2)
22+
23+
def forward(self, x):
24+
x = self.conv2d(x)
25+
x = F.upsample(x, scale_factor=2)
26+
x = self.up(x)
27+
return x
28+
29+
30+
if __name__ == '__main__':
31+
max_error = 0
32+
for i in range(100):
33+
kernel_size = np.random.randint(1, 7)
34+
inp = np.random.randint(kernel_size + 1, 100)
35+
out = np.random.randint(1, 100)
36+
37+
model = TestUpsampleNearest2d(inp, out, kernel_size, inp % 2)
38+
39+
input_np = np.random.uniform(0, 1, (1, inp, inp, inp))
40+
input_var = Variable(torch.FloatTensor(input_np))
41+
output = model(input_var)
42+
43+
k_model = pytorch_to_keras(model, input_var, (inp, inp, inp,), verbose=True)
44+
45+
pytorch_output = output.data.numpy()
46+
keras_output = k_model.predict(input_np)
47+
48+
error = np.max(pytorch_output - keras_output)
49+
print(error)
50+
if max_error < error:
51+
max_error = error
52+
53+
print('Max error: {0}'.format(max_error))

0 commit comments

Comments
 (0)