Skip to content

Commit 733674c

Browse files
committed
Fixed double-flatten issue.
1 parent 79bd571 commit 733674c

File tree

2 files changed

+8
-3
lines changed

2 files changed

+8
-3
lines changed

pytorch2keras/layers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -232,14 +232,19 @@ def convert_flatten(params, w_name, scope_name, inputs, layers, weights, short_n
232232
short_names: use short names for keras layers
233233
"""
234234
print('Conerting reshape ...')
235+
235236
if short_names:
236237
tf_name = 'R' + random_string(7)
237238
else:
238239
tf_name = w_name + str(random.random())
239240

240241
# TODO: check if the input is already flattened
241-
reshape = keras.layers.Flatten(name=tf_name)
242-
layers[scope_name] = reshape(layers[inputs[0]])
242+
# Ad-hoc to avoid it:
243+
if len(list(layers[inputs[0]].shape)) == 2:
244+
layers[scope_name] = layers[inputs[0]]
245+
else:
246+
reshape = keras.layers.Flatten(name=tf_name)
247+
layers[scope_name] = reshape(layers[inputs[0]])
243248

244249

245250
def convert_gemm(params, w_name, scope_name, inputs, layers, weights, short_names):

tests/view.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ def __init__(self, inp=10, out=16, kernel_size=3, bias=True):
1515

1616
def forward(self, x):
1717
x = self.conv2d(x)
18-
x = x.view([x.size(0), -1, 2, 1, 1, 1, 1, 1])
18+
x = x.view([x.size(0), -1, 2, 1, 1, 1, 1, 1]).view(x.size(0), -1).view(x.size(0), -1)
1919
x = torch.nn.Tanh()(x)
2020
return x
2121

0 commit comments

Comments
 (0)