Skip to content

Commit 5b103c2

Browse files
authored
Simplify multi_box_head API in detection.py and remove assign op. (#18310) (#18388)
* Simplify multi_box_head API in detection.py and remove assign op.
1 parent 08b0a7b commit 5b103c2

File tree

2 files changed

+34
-36
lines changed

2 files changed

+34
-36
lines changed

paddle/fluid/operators/detection/mine_hard_examples_op.cc

Lines changed: 23 additions & 17 deletions
Original file line number · Diff line number · Diff line change
@@ -195,25 +195,31 @@ class MineHardExamplesOp : public framework::OperatorWithKernel {
195195
auto loc_loss_dims = ctx->GetInputDim("LocLoss");
196196
PADDLE_ENFORCE_EQ(loc_loss_dims.size(), 2UL,
197197
"The shape of LocLoss is [N, Np].");
198-
PADDLE_ENFORCE_EQ(cls_loss_dims[0], loc_loss_dims[0],
199-
"Batch size of ClsLoss and LocLoss must be the same.");
200-
PADDLE_ENFORCE_EQ(
201-
cls_loss_dims[1], loc_loss_dims[1],
202-
"Prior box number of ClsLoss and LocLoss must be the same.");
198+
if (ctx->IsRuntime()) {
199+
PADDLE_ENFORCE_EQ(
200+
cls_loss_dims[0], loc_loss_dims[0],
201+
"Batch size of ClsLoss and LocLoss must be the same.");
202+
PADDLE_ENFORCE_EQ(
203+
cls_loss_dims[1], loc_loss_dims[1],
204+
"Prior box number of ClsLoss and LocLoss must be the same.");
205+
}
203206
}
204207

205-
PADDLE_ENFORCE_EQ(
206-
cls_loss_dims[0], idx_dims[0],
207-
"Batch size of ClsLoss and MatchIndices must be the same.");
208-
PADDLE_ENFORCE_EQ(
209-
cls_loss_dims[1], idx_dims[1],
210-
"Prior box number of ClsLoss and MatchIndices must be the same.");
211-
212-
PADDLE_ENFORCE_EQ(cls_loss_dims[0], dis_dims[0],
213-
"Batch size of ClsLoss and MatchDist must be the same.");
214-
PADDLE_ENFORCE_EQ(
215-
cls_loss_dims[1], idx_dims[1],
216-
"Prior box number of ClsLoss and MatchDist must be the same.");
208+
if (ctx->IsRuntime()) {
209+
PADDLE_ENFORCE_EQ(
210+
cls_loss_dims[0], idx_dims[0],
211+
"Batch size of ClsLoss and MatchIndices must be the same.");
212+
PADDLE_ENFORCE_EQ(
213+
cls_loss_dims[1], idx_dims[1],
214+
"Prior box number of ClsLoss and MatchIndices must be the same.");
215+
216+
PADDLE_ENFORCE_EQ(
217+
cls_loss_dims[0], dis_dims[0],
218+
"Batch size of ClsLoss and MatchDist must be the same.");
219+
PADDLE_ENFORCE_EQ(
220+
cls_loss_dims[1], idx_dims[1],
221+
"Prior box number of ClsLoss and MatchDist must be the same.");
222+
}
217223

218224
auto mining_type =
219225
GetMiningType(ctx->Attrs().Get<std::string>("mining_type"));

python/paddle/fluid/layers/detection.py

Lines changed: 11 additions & 19 deletions
Original file line number · Diff line number · Diff line change
@@ -1393,8 +1393,10 @@ def __reshape_to_2d(var):
13931393
# 3. Mining hard examples
13941394
actual_shape = nn.slice(conf_shape, axes=[0], starts=[0], ends=[2])
13951395
actual_shape.stop_gradient = True
1396+
# shape=(-1, 0) is set for compile-time, the correct shape is set by
1397+
# actual_shape in runtime.
13961398
conf_loss = nn.reshape(
1397-
x=conf_loss, shape=(num, num_prior), actual_shape=actual_shape)
1399+
x=conf_loss, shape=(-1, 0), actual_shape=actual_shape)
13981400
conf_loss.stop_gradient = True
13991401
neg_indices = helper.create_variable_for_type_inference(dtype='int32')
14001402
dtype = matched_indices.dtype
@@ -1464,7 +1466,9 @@ def __reshape_to_2d(var):
14641466
# 5.3 Compute overall weighted loss.
14651467
loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
14661468
# reshape to [N, Np], N is the batch size and Np is the prior box number.
1467-
loss = nn.reshape(x=loss, shape=(num, num_prior), actual_shape=actual_shape)
1469+
# shape=(-1, 0) is set for compile-time, the correct shape is set by
1470+
# actual_shape in runtime.
1471+
loss = nn.reshape(x=loss, shape=(-1, 0), actual_shape=actual_shape)
14681472
loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
14691473
if normalize:
14701474
normalizer = nn.reduce_sum(target_loc_weight)
@@ -1927,13 +1931,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
19271931
stride=stride)
19281932

19291933
mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
1930-
compile_shape = [
1931-
mbox_loc.shape[0], cpt.floor_division(
1932-
mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3], 4), 4
1933-
]
1934-
run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32"))
1935-
mbox_loc_flatten = nn.reshape(
1936-
mbox_loc, shape=compile_shape, actual_shape=run_shape)
1934+
mbox_loc_flatten = nn.flatten(mbox_loc, axis=1)
19371935
mbox_locs.append(mbox_loc_flatten)
19381936

19391937
# get conf
@@ -1945,16 +1943,7 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
19451943
padding=pad,
19461944
stride=stride)
19471945
conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
1948-
new_shape = [0, -1, num_classes]
1949-
compile_shape = [
1950-
conf_loc.shape[0],
1951-
cpt.floor_division(conf_loc.shape[1] * conf_loc.shape[2] *
1952-
conf_loc.shape[3], num_classes), num_classes
1953-
]
1954-
run_shape = tensor.assign(
1955-
numpy.array([0, -1, num_classes]).astype("int32"))
1956-
conf_loc_flatten = nn.reshape(
1957-
conf_loc, shape=compile_shape, actual_shape=run_shape)
1946+
conf_loc_flatten = nn.flatten(conf_loc, axis=1)
19581947
mbox_confs.append(conf_loc_flatten)
19591948

19601949
if len(box_results) == 1:
@@ -1972,7 +1961,10 @@ def _is_list_or_tuple_and_equal(data, length, err_info):
19721961
box = tensor.concat(reshaped_boxes)
19731962
var = tensor.concat(reshaped_vars)
19741963
mbox_locs_concat = tensor.concat(mbox_locs, axis=1)
1964+
mbox_locs_concat = nn.reshape(mbox_locs_concat, shape=[0, -1, 4])
19751965
mbox_confs_concat = tensor.concat(mbox_confs, axis=1)
1966+
mbox_confs_concat = nn.reshape(
1967+
mbox_confs_concat, shape=[0, -1, num_classes])
19761968

19771969
box.stop_gradient = True
19781970
var.stop_gradient = True

0 commit comments

Comments (0)