Skip to content

Commit 9ca22d8

Browse files
committed
Added adaptive average pooling (only to 1x1) as a global average pooling.
1 parent ca02f31 commit 9ca22d8

File tree

1 file changed

+40
-8
lines changed

1 file changed

+40
-8
lines changed

pytorch2keras/layers.py

Lines changed: 40 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -630,10 +630,14 @@ def convert_reshape(params, w_name, scope_name, inputs, layers, weights):
630630
weights: pytorch state_dict
631631
"""
632632
print('Converting reshape ...')
633-
634-
tf_name = w_name + str(random.random())
635-
reshape = keras.layers.Reshape(params['shape'][1:], name=tf_name)
636-
layers[scope_name] = reshape(layers[inputs[0]])
633+
if len(inputs) > 1:
634+
tf_name = w_name + str(random.random())
635+
reshape = keras.layers.Reshape(layers[inputs[1]][1:], name=tf_name)
636+
layers[scope_name] = reshape(layers[inputs[0]])
637+
else:
638+
tf_name = w_name + str(random.random())
639+
reshape = keras.layers.Reshape(params['shape'][1:], name=tf_name)
640+
layers[scope_name] = reshape(layers[inputs[0]])
637641

638642

639643
def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
@@ -750,11 +754,12 @@ def convert_constant(params, w_name, scope_name, inputs, layers, weights):
750754
"""
751755
print('Converting constant ...')
752756

753-
def target_layer(params=params):
754-
return keras.backend.constant(np.float32(params['value']))
757+
# def target_layer(x, params=params):
758+
# return keras.backend.constant(np.float32(params['value']))
755759

756-
lambda_layer = keras.layers.Lambda(target_layer)
757-
layers[scope_name] = lambda_layer(layers[inputs[0]])
760+
# lambda_layer = keras.layers.Lambda(target_layer)
761+
# layers[scope_name] = lambda_layer(layers[inputs[0]])
762+
layers[scope_name] = np.float32(params['value'])
758763

759764

760765
def convert_upsample(params, w_name, scope_name, inputs, layers, weights):
@@ -815,6 +820,32 @@ def convert_padding(params, w_name, scope_name, inputs, layers, weights):
815820
layers[scope_name] = padding_layer(layers[inputs[0]])
816821

817822

def convert_adaptive_avg_pool2d(params, w_name, scope_name, inputs, layers, weights):
    """
    Convert adaptive_avg_pool2d layer (supported only for 1x1 output size).

    PyTorch's AdaptiveAvgPool2d(1) produces a rank-4 tensor (N, C, 1, 1),
    while Keras GlobalAveragePooling2D flattens to rank-2 (N, C).  We pool
    globally, then restore BOTH collapsed spatial dimensions so downstream
    converted layers see the same shape PyTorch would produce.

    Args:
        params: dictionary with layer parameters
        w_name: name prefix in state_dict
        scope_name: pytorch scope name
        inputs: pytorch node inputs
        layers: dictionary with keras tensors
        weights: pytorch state_dict
    """
    print('Converting adaptive_avg_pool2d...')

    tf_name = w_name + str(random.random())
    # Pass tf_name to the layer (it was previously computed but unused),
    # matching the naming convention used by the other converters.
    global_pool = keras.layers.GlobalAveragePooling2D(name=tf_name)
    layers_global_pool = global_pool(layers[inputs[0]])

    def target_layer(x):
        # Two expand_dims calls restore (N, C) -> (N, C, 1) -> (N, C, 1, 1).
        # A single expand_dims would yield rank-3 (N, C, 1), which does not
        # match PyTorch's AdaptiveAvgPool2d(1) output.
        # NOTE(review): assumes channels_first data format as used by the
        # rest of this converter — confirm against keras image_data_format.
        return keras.backend.expand_dims(keras.backend.expand_dims(x))

    lambda_layer = keras.layers.Lambda(target_layer)
    layers[scope_name] = lambda_layer(layers_global_pool)
818849
AVAILABLE_CONVERTERS = {
819850
'onnx::Conv': convert_conv,
820851
'onnx::ConvTranspose': convert_convtranspose,
@@ -844,4 +875,5 @@ def convert_padding(params, w_name, scope_name, inputs, layers, weights):
844875
'onnx::Constant': convert_constant,
845876
'onnx::Upsample': convert_upsample,
846877
'onnx::Pad': convert_padding,
878+
'aten::adaptive_avg_pool2d': convert_adaptive_avg_pool2d,
847879
}

0 commit comments

Comments
 (0)