@@ -630,10 +630,14 @@ def convert_reshape(params, w_name, scope_name, inputs, layers, weights):
         weights: pytorch state_dict
     """
     print('Converting reshape ...')
-
-    tf_name = w_name + str(random.random())
-    reshape = keras.layers.Reshape(params['shape'][1:], name=tf_name)
-    layers[scope_name] = reshape(layers[inputs[0]])
+    if len(inputs) > 1:
+        tf_name = w_name + str(random.random())
+        reshape = keras.layers.Reshape(layers[inputs[1]][1:], name=tf_name)
+        layers[scope_name] = reshape(layers[inputs[0]])
+    else:
+        tf_name = w_name + str(random.random())
+        reshape = keras.layers.Reshape(params['shape'][1:], name=tf_name)
+        layers[scope_name] = reshape(layers[inputs[0]])
 
 
 def convert_matmul(params, w_name, scope_name, inputs, layers, weights):
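A hedged sketch of the new two-input path (standalone, not part of the commit; assumes Keras 2.x, as this repo uses). ONNX Reshape can carry its target shape as a second input rather than a 'shape' attribute; that shape reaches the layers dict as a plain array (see the convert_constant change below), and Keras Reshape expects it without the batch dimension, hence the [1:] slice:

    import numpy as np
    import keras

    inp = keras.layers.Input(shape=(8, 4))     # batch axis is implicit
    shape_value = np.array([1, 4, 8])          # e.g. as stored by convert_constant
    reshape = keras.layers.Reshape(tuple(int(d) for d in shape_value[1:]))
    print(keras.backend.int_shape(reshape(inp)))   # (None, 4, 8)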
@@ -750,11 +754,12 @@ def convert_constant(params, w_name, scope_name, inputs, layers, weights):
     """
     print('Converting constant ...')
 
-    def target_layer(params=params):
-        return keras.backend.constant(np.float32(params['value']))
+    # def target_layer(x, params=params):
+    #     return keras.backend.constant(np.float32(params['value']))
 
-    lambda_layer = keras.layers.Lambda(target_layer)
-    layers[scope_name] = lambda_layer(layers[inputs[0]])
+    # lambda_layer = keras.layers.Lambda(target_layer)
+    # layers[scope_name] = lambda_layer(layers[inputs[0]])
+    layers[scope_name] = np.float32(params['value'])
 
 
 def convert_upsample(params, w_name, scope_name, inputs, layers, weights):
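Storing the raw value means constants now live in the layers dict as ordinary numpy data instead of Keras tensors, which is what lets the reshape branch above index into layers[inputs[1]]. A minimal sketch of that hand-off (names hypothetical):

    import numpy as np

    layers = {}
    layers['constant_1'] = np.float32([1, 4, 8])   # what convert_constant now stores
    target_shape = layers['constant_1'][1:]        # how convert_reshape reads it
    print(target_shape)                            # [4. 8.]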
@@ -815,6 +820,32 @@ def convert_padding(params, w_name, scope_name, inputs, layers, weights):
     layers[scope_name] = padding_layer(layers[inputs[0]])
 
 
+
+def convert_adaptive_avg_pool2d(params, w_name, scope_name, inputs, layers, weights):
+    """
+    Convert adaptive_avg_pool2d layer.
+
+    Args:
+        params: dictionary with layer parameters
+        w_name: name prefix in state_dict
+        scope_name: pytorch scope name
+        inputs: pytorch node inputs
+        layers: dictionary with keras tensors
+        weights: pytorch state_dict
+    """
+    print('Converting adaptive_avg_pool2d...')
+
+    tf_name = w_name + str(random.random())
+    global_pool = keras.layers.GlobalAveragePooling2D()
+    layers_global_pool = global_pool(layers[inputs[0]])
+
+    def target_layer(x):
+        return keras.backend.expand_dims(x)
+
+    lambda_layer = keras.layers.Lambda(target_layer)
+    layers[scope_name] = lambda_layer(layers_global_pool)
+
+
 AVAILABLE_CONVERTERS = {
     'onnx::Conv': convert_conv,
     'onnx::ConvTranspose': convert_convtranspose,
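This mapping is exact only when adaptive_avg_pool2d targets a 1x1 output, where adaptive average pooling reduces to a global average over the spatial axes: GlobalAveragePooling2D collapses height and width to give (batch, channels), and the Lambda re-adds one axis with expand_dims. A standalone sketch of the equivalence (assumes Keras 2.x, channels-last data):

    import keras

    x = keras.layers.Input(shape=(7, 7, 512))
    pooled = keras.layers.GlobalAveragePooling2D()(x)                  # (None, 512)
    restored = keras.layers.Lambda(keras.backend.expand_dims)(pooled)  # (None, 512, 1)
    print(keras.backend.int_shape(pooled), keras.backend.int_shape(restored))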
@@ -844,4 +875,5 @@ def convert_padding(params, w_name, scope_name, inputs, layers, weights):
     'onnx::Constant': convert_constant,
     'onnx::Upsample': convert_upsample,
     'onnx::Pad': convert_padding,
+    'aten::adaptive_avg_pool2d': convert_adaptive_avg_pool2d,
 }