@@ -3257,6 +3257,31 @@ def func(filter_val, out_backprop_val):
3257
3257
out_backprop_val = np .random .randint (low = 0 , high = 256 , size = [1 , 5 , 5 , 5 ]).astype (np .float32 )
3258
3258
self ._run_test_case (func , [_OUTPUT ], {_INPUT : filters_val , _INPUT1 : out_backprop_val })
3259
3259
3260
+ @check_tf_min_version ("1.15" , "tf.repeat needs tf 1.15" )
3261
+ @check_opset_min_version (10 , "Conv2DBackpropInput" )
3262
+ def test_Conv2DBackpropInput_shape_implied (self ):
3263
+ batch_dim_val = np .array (1 , dtype = np .int32 )
3264
+ def func (filter_val , out_backprop_val , batch_dim ):
3265
+ out_backprop_val = tf .repeat (out_backprop_val , batch_dim , axis = 0 )
3266
+ s = tf .shape (out_backprop_val )
3267
+ t1 = tf .constant ([0 ], dtype = tf .int32 )
3268
+ t2 = tf .constant ([1 ], dtype = tf .int32 )
3269
+ batch_dim = tf .strided_slice (s , t1 , t2 , shrink_axis_mask = 1 )
3270
+ # Sometimes the size given is a stack of constants with unknown batch dim
3271
+ input_sizes_val = tf .stack ([batch_dim , 10 , 10 , 3 ])
3272
+ return conv2d_backprop_input (input_sizes = input_sizes_val , filter = filter_val ,
3273
+ out_backprop = out_backprop_val , strides = [1 , 2 , 2 , 1 ],
3274
+ padding = 'SAME' , name = _TFOUTPUT )
3275
+ filters_val = np .random .randint (low = 0 , high = 256 , size = [3 , 3 , 3 , 5 ]).astype (np .float32 )
3276
+ out_backprop_val = np .random .randint (low = 0 , high = 256 , size = [1 , 5 , 5 , 5 ]).astype (np .float32 )
3277
+ def graph_validator (g ):
3278
+ for n in g .get_nodes ():
3279
+ if n .type == 'ConvTranspose' :
3280
+ return "output_shape" in n .attr
3281
+ return False
3282
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : filters_val , _INPUT1 : out_backprop_val , _INPUT2 : batch_dim_val },
3283
+ graph_validator = graph_validator )
3284
+
3260
3285
@check_opset_min_version (10 , "Conv2DBackpropInput" )
3261
3286
def test_Conv2DBackpropInput_const_valid (self ):
3262
3287
input_sizes_val_ = np .array ([1 , 12 , 12 , 3 ], dtype = np .int32 )
0 commit comments