
Commit fa7ace7

Cherry pick from #21862 (#22194)
* Fix default label dim of label_smooth_op. test=develop (#21862)
* Fix unit tests of label_smooth_op's data size.
1 parent c7248cd commit fa7ace7
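
For context, label smoothing replaces a hard one-hot target with a mixture of the original label and either a uniform distribution or a user-supplied prior. Below is a minimal NumPy sketch of the op's semantics (not Paddle code; the helper name label_smooth and the sample inputs are illustrative). Smoothing acts along the last axis, which is exactly the dimension this commit makes the kernels use.

import numpy as np

def label_smooth(label, epsilon, prior_dist=None):
    # label: (..., label_dim); smoothing acts along the last axis.
    label_dim = label.shape[-1]
    if prior_dist is None:
        # Uniform smoothing: every class receives epsilon / label_dim.
        return (1 - epsilon) * label + epsilon / label_dim
    # PriorDist must contain exactly label_dim elements.
    return (1 - epsilon) * label + epsilon * prior_dist.reshape(label_dim)

one_hot = np.eye(4)[[0, 2]]  # shape (2, 4)
print(label_smooth(one_hot, 0.1))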

4 files changed, +24 -6 lines

paddle/fluid/operators/label_smooth_op.cc

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ class LabelSmoothOp : public framework::OperatorWithKernel {
       auto noise_dims = ctx->GetInputDim("PriorDist");
       auto noise_numel = paddle::framework::product(noise_dims);
       PADDLE_ENFORCE(
-          in_dims[1] == noise_numel,
+          in_dims[in_dims.size() - 1] == noise_numel,
           "The number of elements in Input(PriorDist) must be equal to the "
           "dimension of each label.");
     }
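
The old check compared the PriorDist size against in_dims[1], which is only correct for 2-D labels. A small NumPy sketch (shapes are made up for illustration) of why the last dimension is the right one for a 3-D label tensor:

import numpy as np

# Hypothetical 3-D labels: (batch, seq_len, num_classes) = (2, 5, 10).
x = np.zeros((2, 5, 10))
prior_dist = np.full(10, 0.1)

# Old check: in_dims[1] == noise_numel  ->  5 == 10, wrongly rejects this input.
# New check: last dim == noise_numel    -> 10 == 10, accepts it.
assert x.shape[-1] == prior_dist.size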

paddle/fluid/operators/label_smooth_op.cu

Lines changed: 2 additions & 2 deletions
@@ -34,7 +34,7 @@ __global__ void LabelSmoothRunDistKernel(const int N, const float epsilon,
                                          const T* dist_data, T* dst) {
   int idx = blockDim.x * blockIdx.x + threadIdx.x;
   for (; idx < N; idx += blockDim.x * gridDim.x) {
-    int dist_idx = idx - (idx / dist_numel) * dist_numel;
+    int dist_idx = idx % dist_numel;
     dst[idx] = static_cast<T>(1 - epsilon) * src[idx] +
                static_cast<T>(epsilon) * dist_data[dist_idx];
   }
@@ -56,7 +56,7 @@ class LabelSmoothGPUKernel : public framework::OpKernel<T> {
     auto* out_t = ctx.Output<framework::LoDTensor>("Out");
     auto* in_t = ctx.Input<framework::LoDTensor>("X");
     auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
-    auto label_dim = in_t->dims()[1];
+    auto label_dim = in_t->dims()[in_t->dims().size() - 1];
     auto epsilon = ctx.Attr<float>("epsilon");
     auto& dev = *ctx.template device_context<DeviceContext>().eigen_device();
     auto size_prob = in_t->numel();
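
The first hunk is a pure simplification: for the non-negative indices produced by the grid-stride loop, idx - (idx / dist_numel) * dist_numel and idx % dist_numel compute the same class offset. A quick sanity check in Python (values arbitrary):

dist_numel = 10
for idx in range(35):
    old = idx - (idx // dist_numel) * dist_numel
    new = idx % dist_numel
    assert old == new  # identical for non-negative idx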

paddle/fluid/operators/label_smooth_op.h

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
     auto* out_t = ctx.Output<framework::LoDTensor>("Out");
     auto* in_t = ctx.Input<framework::LoDTensor>("X");
     auto* dist_t = ctx.Input<framework::Tensor>("PriorDist");
-    auto label_dim = in_t->dims()[1];
+    auto label_dim = in_t->dims()[in_t->dims().size() - 1];
     out_t->mutable_data<T>(ctx.GetPlace());

     auto epsilon = ctx.Attr<float>("epsilon");
@@ -39,7 +39,7 @@ class LabelSmoothKernel : public framework::OpKernel<T> {
       out.device(dev) =
           static_cast<T>(1 - epsilon) * in +
           static_cast<T>(epsilon) *
-              dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel()));
+              dist.broadcast(Eigen::DSizes<int, 1>(in_t->numel() / label_dim));
     } else {
       out.device(dev) = static_cast<T>(1 - epsilon) * in +
                         static_cast<T>(epsilon / label_dim);
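
The broadcast hunk matters because dist has label_dim elements while the flattened input has numel elements, so the prior must be repeated numel / label_dim times (once per label row), not numel times. A rough NumPy analogue of the Eigen expression (epsilon and shapes are illustrative):

import numpy as np

epsilon, label_dim = 0.1, 4
x = np.eye(label_dim)[[0, 2, 1]]          # 3 label rows, numel = 12
dist = np.array([0.4, 0.3, 0.2, 0.1])

flat = x.reshape(-1)
# dist is tiled numel / label_dim = 3 times to line up with the flattened input.
tiled = np.tile(dist, flat.size // label_dim)
out = (1 - epsilon) * flat + epsilon * tiled
print(out.reshape(x.shape))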

python/paddle/fluid/tests/unittests/test_label_smooth_op.py

Lines changed: 19 additions & 1 deletion
@@ -23,7 +23,7 @@ class TestLabelSmoothOp(OpTest):
     def config(self):
         self.op_type = "label_smooth"
         self.epsilon = 0.1
-        batch_size, self.label_dim = 5, 10
+        batch_size, self.label_dim = 10, 12
         self.label = np.zeros((batch_size, self.label_dim)).astype("float64")
         nonzero_index = np.random.randint(self.label_dim, size=(batch_size))
         self.label[np.arange(batch_size), nonzero_index] = 1
@@ -53,5 +53,23 @@ def setUp(self):
         self.outputs = {'Out': smoothed_label}


+class TestLabelSmoothOp3D(TestLabelSmoothOp):
+    def setUp(self):
+        super(TestLabelSmoothOp3D, self).setUp()
+        self.inputs['X'] = self.inputs['X'].reshape(
+            [2, -1, self.inputs['X'].shape[-1]])
+        self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
+                                                          .shape)
+
+
+class TestLabelSmoothOpWithPriorDist3D(TestLabelSmoothOpWithPriorDist):
+    def setUp(self):
+        super(TestLabelSmoothOpWithPriorDist3D, self).setUp()
+        self.inputs['X'] = self.inputs['X'].reshape(
+            [2, -1, self.inputs['X'].shape[-1]])
+        self.outputs['Out'] = self.outputs['Out'].reshape(self.inputs['X']
+                                                          .shape)
+
+
 if __name__ == '__main__':
     unittest.main()
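
The new 3-D cases simply reshape the 2-D reference data and its expected output, which is valid because smoothing only touches the last axis. A sketch of the invariant these tests rely on (uniform-smoothing case only, values illustrative):

import numpy as np

epsilon, label_dim = 0.1, 12
label2d = np.eye(label_dim)[np.random.randint(label_dim, size=10)]
smoothed2d = (1 - epsilon) * label2d + epsilon / label_dim

label3d = label2d.reshape(2, -1, label_dim)
smoothed3d = (1 - epsilon) * label3d + epsilon / label_dim
assert np.allclose(smoothed3d, smoothed2d.reshape(label3d.shape))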
