@@ -87,6 +87,17 @@ def get_conv_getdata(kind=1):
87
87
else :
88
88
raise ValueError ("kind not known" )
89
89
90
+ def get_maxpoolwithargmax_getdata ():
91
+ data = [
92
+ ('SAME' , [1 , 3 , 3 , 1 ], [1 , 3 , 3 , 1 ], [1 , 2 , 2 , 1 ]),
93
+ ('SAME' , [1 , 5 , 5 , 1 ], [1 , 4 , 4 , 1 ], [1 , 2 , 2 , 1 ]),
94
+ ('SAME' , [1 , 10 , 5 , 1 ], [1 , 2 , 2 , 1 ], [1 , 2 , 2 , 1 ]),
95
+ ('SAME' , [1 , 10 , 5 , 1 ], [1 , 4 , 4 , 1 ], [1 , 1 , 1 , 1 ]),
96
+ ('VALID' , [1 , 3 , 3 , 1 ], [1 , 3 , 3 , 1 ], [1 , 2 , 2 , 1 ]),
97
+ ('VALID' , [1 , 5 , 5 , 1 ], [1 , 4 , 4 , 1 ], [1 , 2 , 2 , 1 ]),
98
+ ]
99
+ for idx , v in enumerate (data ):
100
+ yield (idx ,) + v
90
101
91
102
class BackendTests (Tf2OnnxBackendTestBase ):
92
103
def _run_test_case (self , output_names_with_port , feed_dict , ** kwargs ):
@@ -2236,6 +2247,24 @@ def test_thresholded_relu(self):
2236
2247
graph_validator = lambda g : check_op_count (g , "ThresholdedRelu" , 1 ))
2237
2248
tf .reset_default_graph ()
2238
2249
2250
+ @check_tf_min_version ("1.13" )
2251
+ @check_opset_min_version (8 , "MaxPoolWithArgmax" )
2252
+ def test_maxpoolwithargmax (self ):
2253
+ for tf_shape in ["known" , "unknown" ]:
2254
+ tf .reset_default_graph ()
2255
+ for p in get_maxpoolwithargmax_getdata ():
2256
+ _ , padding , x_shape , ksize , strides = p
2257
+ tf .reset_default_graph ()
2258
+ x_val = make_xval (x_shape )
2259
+ if tf_shape == "known" :
2260
+ x = tf .placeholder (tf .float32 , shape = x_val .shape , name = _TFINPUT )
2261
+ else :
2262
+ x = tf .placeholder (tf .float32 , shape = [None ] * x_val .ndim , name = _TFINPUT )
2263
+ mp = tf .nn .max_pool_with_argmax (x , ksize , strides , padding = padding )
2264
+ _ = tf .identity (mp [0 ], name = _TFOUTPUT )
2265
+ _ = tf .identity (mp [1 ], name = _TFOUTPUT1 )
2266
+ self .logger .debug (str (p ))
2267
+ self ._run_test_case ([_OUTPUT , _OUTPUT1 ], {_INPUT : x_val })
2239
2268
2240
2269
if __name__ == '__main__' :
2241
2270
unittest_main ()
0 commit comments