@@ -3282,6 +3282,31 @@ def func(filter_val, out_backprop_val):
3282
3282
out_backprop_val = np .random .randint (low = 0 , high = 256 , size = [1 , 5 , 5 , 5 ]).astype (np .float32 )
3283
3283
self ._run_test_case (func , [_OUTPUT ], {_INPUT : filters_val , _INPUT1 : out_backprop_val })
3284
3284
3285
+ @check_tf_min_version ("1.15" , "tf.repeat needs tf 1.15" )
3286
+ @check_opset_min_version (10 , "Conv2DBackpropInput" )
3287
+ def test_Conv2DBackpropInput_shape_implied (self ):
3288
+ batch_dim_val = np .array (1 , dtype = np .int32 )
3289
+ def func (filter_val , out_backprop_val , batch_dim ):
3290
+ out_backprop_val = tf .repeat (out_backprop_val , batch_dim , axis = 0 )
3291
+ s = tf .shape (out_backprop_val )
3292
+ t1 = tf .constant ([0 ], dtype = tf .int32 )
3293
+ t2 = tf .constant ([1 ], dtype = tf .int32 )
3294
+ batch_dim = tf .strided_slice (s , t1 , t2 , shrink_axis_mask = 1 )
3295
+ # Sometimes the size given is a stack of constants with unknown batch dim
3296
+ input_sizes_val = tf .stack ([batch_dim , 10 , 10 , 3 ])
3297
+ return conv2d_backprop_input (input_sizes = input_sizes_val , filter = filter_val ,
3298
+ out_backprop = out_backprop_val , strides = [1 , 2 , 2 , 1 ],
3299
+ padding = 'SAME' , name = _TFOUTPUT )
3300
+ filters_val = np .random .randint (low = 0 , high = 256 , size = [3 , 3 , 3 , 5 ]).astype (np .float32 )
3301
+ out_backprop_val = np .random .randint (low = 0 , high = 256 , size = [1 , 5 , 5 , 5 ]).astype (np .float32 )
3302
+ def graph_validator (g ):
3303
+ for n in g .get_nodes ():
3304
+ if n .type == 'ConvTranspose' :
3305
+ return "output_shape" in n .attr
3306
+ return False
3307
+ self ._run_test_case (func , [_OUTPUT ], {_INPUT : filters_val , _INPUT1 : out_backprop_val , _INPUT2 : batch_dim_val },
3308
+ graph_validator = graph_validator )
3309
+
3285
3310
@check_opset_min_version (10 , "Conv2DBackpropInput" )
3286
3311
def test_Conv2DBackpropInput_const_valid (self ):
3287
3312
input_sizes_val_ = np .array ([1 , 12 , 12 , 3 ], dtype = np .int32 )
0 commit comments