@@ -1393,8 +1393,10 @@ def __reshape_to_2d(var):
1393
1393
# 3. Mining hard examples
1394
1394
actual_shape = nn .slice (conf_shape , axes = [0 ], starts = [0 ], ends = [2 ])
1395
1395
actual_shape .stop_gradient = True
1396
+ # shape=(-1, 0) is set for compile-time, the correct shape is set by
1397
+ # actual_shape in runtime.
1396
1398
conf_loss = nn .reshape (
1397
- x = conf_loss , shape = (num , num_prior ), actual_shape = actual_shape )
1399
+ x = conf_loss , shape = (- 1 , 0 ), actual_shape = actual_shape )
1398
1400
conf_loss .stop_gradient = True
1399
1401
neg_indices = helper .create_variable_for_type_inference (dtype = 'int32' )
1400
1402
dtype = matched_indices .dtype
@@ -1464,7 +1466,9 @@ def __reshape_to_2d(var):
1464
1466
# 5.3 Compute overall weighted loss.
1465
1467
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
1466
1468
# reshape to [N, Np], N is the batch size and Np is the prior box number.
1467
- loss = nn .reshape (x = loss , shape = (num , num_prior ), actual_shape = actual_shape )
1469
+ # shape=(-1, 0) is set for compile-time, the correct shape is set by
1470
+ # actual_shape in runtime.
1471
+ loss = nn .reshape (x = loss , shape = (- 1 , 0 ), actual_shape = actual_shape )
1468
1472
loss = nn .reduce_sum (loss , dim = 1 , keep_dim = True )
1469
1473
if normalize :
1470
1474
normalizer = nn .reduce_sum (target_loc_weight )
@@ -1927,13 +1931,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
1927
1931
stride = stride )
1928
1932
1929
1933
mbox_loc = nn .transpose (mbox_loc , perm = [0 , 2 , 3 , 1 ])
1930
- compile_shape = [
1931
- mbox_loc .shape [0 ], cpt .floor_division (
1932
- mbox_loc .shape [1 ] * mbox_loc .shape [2 ] * mbox_loc .shape [3 ], 4 ), 4
1933
- ]
1934
- run_shape = tensor .assign (numpy .array ([0 , - 1 , 4 ]).astype ("int32" ))
1935
- mbox_loc_flatten = nn .reshape (
1936
- mbox_loc , shape = compile_shape , actual_shape = run_shape )
1934
+ mbox_loc_flatten = nn .flatten (mbox_loc , axis = 1 )
1937
1935
mbox_locs .append (mbox_loc_flatten )
1938
1936
1939
1937
# get conf
@@ -1945,16 +1943,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
1945
1943
padding = pad ,
1946
1944
stride = stride )
1947
1945
conf_loc = nn .transpose (conf_loc , perm = [0 , 2 , 3 , 1 ])
1948
- new_shape = [0 , - 1 , num_classes ]
1949
- compile_shape = [
1950
- conf_loc .shape [0 ],
1951
- cpt .floor_division (conf_loc .shape [1 ] * conf_loc .shape [2 ] *
1952
- conf_loc .shape [3 ], num_classes ), num_classes
1953
- ]
1954
- run_shape = tensor .assign (
1955
- numpy .array ([0 , - 1 , num_classes ]).astype ("int32" ))
1956
- conf_loc_flatten = nn .reshape (
1957
- conf_loc , shape = compile_shape , actual_shape = run_shape )
1946
+ conf_loc_flatten = nn .flatten (conf_loc , axis = 1 )
1958
1947
mbox_confs .append (conf_loc_flatten )
1959
1948
1960
1949
if len (box_results ) == 1 :
@@ -1972,7 +1961,10 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
1972
1961
box = tensor .concat (reshaped_boxes )
1973
1962
var = tensor .concat (reshaped_vars )
1974
1963
mbox_locs_concat = tensor .concat (mbox_locs , axis = 1 )
1964
+ mbox_locs_concat = nn .reshape (mbox_locs_concat , shape = [0 , - 1 , 4 ])
1975
1965
mbox_confs_concat = tensor .concat (mbox_confs , axis = 1 )
1966
+ mbox_confs_concat = nn .reshape (
1967
+ mbox_confs_concat , shape = [0 , - 1 , num_classes ])
1976
1968
1977
1969
box .stop_gradient = True
1978
1970
var .stop_gradient = True
0 commit comments