Skip to content

Commit 1d89866

Browse files
authored
Merge pull request #7593 from JiayiFeng/dev_elementwise_scalar
Make elementwise_op supporting scalar input `Y`
2 parents 37a9437 + 1930961 commit 1d89866

10 files changed

+122
-35
lines changed

paddle/operators/elementwise_div_op.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,16 @@ limitations under the License. */
1919
namespace paddle {
2020
namespace operators {
2121

22+
template <typename T>
23+
struct DivFunctor {
24+
inline HOSTDEVICE T operator()(T a, T b) const { return a / b; }
25+
};
26+
2227
template <typename DeviceContext, typename T>
2328
class ElementwiseDivKernel : public framework::OpKernel<T> {
2429
public:
2530
void Compute(const framework::ExecutionContext& ctx) const override {
26-
ElementwiseCompute<EigenDivFunctor, DeviceContext, T>(ctx);
31+
ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx);
2732
}
2833
};
2934

paddle/operators/elementwise_mul_op.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,16 @@ limitations under the License. */
1818
namespace paddle {
1919
namespace operators {
2020

21+
template <typename T>
22+
struct MulFunctor {
23+
inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
24+
};
25+
2126
template <typename DeviceContext, typename T>
2227
class ElementwiseMulKernel : public framework::OpKernel<T> {
2328
public:
2429
void Compute(const framework::ExecutionContext& ctx) const override {
25-
ElementwiseCompute<EigenMulFunctor, DeviceContext, T>(ctx);
30+
ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx);
2631
}
2732
};
2833

paddle/operators/elementwise_op_function.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,13 @@ void ElementwiseGradCompute(const framework::ExecutionContext& ctx) {
340340
return;
341341
}
342342

343+
if (y_dims.size() == 1 && y_dims[0] == 1) {
344+
// y is a scalar
345+
auto extended_dims = framework::vectorize(x_dims);
346+
extended_dims.push_back(1);
347+
x_dims = framework::make_ddim(extended_dims);
348+
}
349+
343350
int axis = ctx.Attr<int>("axis");
344351
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
345352

@@ -378,6 +385,13 @@ void ElementwiseComputeEx(const framework::ExecutionContext& ctx) {
378385
return;
379386
}
380387

388+
if (y_dims.size() == 1 && y_dims[0] == 1) {
389+
// y is a scalar
390+
auto extended_dims = framework::vectorize(x_dims);
391+
extended_dims.push_back(1);
392+
x_dims = framework::make_ddim(extended_dims);
393+
}
394+
381395
int axis = ctx.Attr<int>("axis");
382396
axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
383397
PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(),

paddle/operators/elementwise_sub_op.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,16 @@ limitations under the License. */
1818
namespace paddle {
1919
namespace operators {
2020

21+
template <typename T>
22+
struct SubFunctor {
23+
inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
24+
};
25+
2126
template <typename DeviceContext, typename T>
2227
class ElementwiseSubKernel : public framework::OpKernel<T> {
2328
public:
2429
void Compute(const framework::ExecutionContext& ctx) const override {
25-
ElementwiseCompute<EigenSubFunctor, DeviceContext, T>(ctx);
30+
ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx);
2631
}
2732
};
2833

python/paddle/v2/fluid/tests/test_elementwise_add_op.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
22
#
3-
#Licensed under the Apache License, Version 2.0 (the "License");
4-
#you may not use this file except in compliance with the License.
5-
#You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66
#
77
# http://www.apache.org/licenses/LICENSE-2.0
88
#
9-
#Unless required by applicable law or agreed to in writing, software
10-
#distributed under the License is distributed on an "AS IS" BASIS,
11-
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
#See the License for the specific language governing permissions and
13-
#limitations under the License.
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1414
import unittest
1515
import numpy as np
1616
from op_test import OpTest
@@ -40,6 +40,16 @@ def test_check_grad_ingore_y(self):
4040
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
4141

4242

43+
class TestElementwiseAddOp_scalar(TestElementwiseOp):
44+
def setUp(self):
45+
self.op_type = "elementwise_add"
46+
self.inputs = {
47+
'X': np.random.rand(2, 3, 4).astype(np.float32),
48+
'Y': np.random.rand(1).astype(np.float32)
49+
}
50+
self.outputs = {'Out': self.inputs['X'] + self.inputs['Y']}
51+
52+
4353
class TestElementwiseAddOp_Vector(TestElementwiseOp):
4454
def setUp(self):
4555
self.op_type = "elementwise_add"

python/paddle/v2/fluid/tests/test_elementwise_div_op.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
22
#
3-
#Licensed under the Apache License, Version 2.0 (the "License");
4-
#you may not use this file except in compliance with the License.
5-
#You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66
#
77
# http://www.apache.org/licenses/LICENSE-2.0
88
#
9-
#Unless required by applicable law or agreed to in writing, software
10-
#distributed under the License is distributed on an "AS IS" BASIS,
11-
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
#See the License for the specific language governing permissions and
13-
#limitations under the License.
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1414
import unittest
1515
import numpy as np
1616
from op_test import OpTest
@@ -45,6 +45,16 @@ def test_check_grad_ingore_y(self):
4545
['X'], 'Out', max_relative_error=0.05, no_grad_set=set('Y'))
4646

4747

48+
class TestElementwiseDivOp_scalar(ElementwiseDivOp):
49+
def setUp(self):
50+
self.op_type = "elementwise_div"
51+
self.inputs = {
52+
'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype(np.float32),
53+
'Y': np.random.uniform(0.1, 1, [1]).astype(np.float32)
54+
}
55+
self.outputs = {'Out': self.inputs['X'] / self.inputs['Y']}
56+
57+
4858
class TestElementwiseDivOp_Vector(ElementwiseDivOp):
4959
def setUp(self):
5060
self.op_type = "elementwise_div"

python/paddle/v2/fluid/tests/test_elementwise_max_op.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def test_check_grad_ingore_y(self):
4343
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
4444

4545

46+
class TestElementwiseMaxOp_scalar(TestElementwiseOp):
47+
def setUp(self):
48+
self.op_type = "elementwise_max"
49+
x = np.random.random_integers(-5, 5, [2, 3, 4]).astype("float32")
50+
y = np.array([0.5]).astype("float32")
51+
self.inputs = {'X': x, 'Y': y}
52+
self.outputs = {'Out': np.maximum(self.inputs['X'], self.inputs['Y'])}
53+
54+
4655
class TestElementwiseMaxOp_Vector(TestElementwiseOp):
4756
def setUp(self):
4857
self.op_type = "elementwise_max"

python/paddle/v2/fluid/tests/test_elementwise_min_op.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,15 @@ def test_check_grad_ingore_y(self):
4343
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
4444

4545

46+
class TestElementwiseMinOp_scalar(TestElementwiseOp):
47+
def setUp(self):
48+
self.op_type = "elementwise_min"
49+
x = np.random.random_integers(-5, 5, [2, 3, 4]).astype("float32")
50+
y = np.array([0.5]).astype("float32")
51+
self.inputs = {'X': x, 'Y': y}
52+
self.outputs = {'Out': np.minimum(self.inputs['X'], self.inputs['Y'])}
53+
54+
4655
class TestElementwiseMaxOp_Vector(TestElementwiseOp):
4756
def setUp(self):
4857
self.op_type = "elementwise_min"

python/paddle/v2/fluid/tests/test_elementwise_mul_op.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
22
#
3-
#Licensed under the Apache License, Version 2.0 (the "License");
4-
#you may not use this file except in compliance with the License.
5-
#You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66
#
77
# http://www.apache.org/licenses/LICENSE-2.0
88
#
9-
#Unless required by applicable law or agreed to in writing, software
10-
#distributed under the License is distributed on an "AS IS" BASIS,
11-
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
#See the License for the specific language governing permissions and
13-
#limitations under the License.
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1414
import unittest
1515
import numpy as np
1616
from op_test import OpTest
@@ -38,6 +38,16 @@ def test_check_grad_ingore_y(self):
3838
self.check_grad(['X'], 'Out', no_grad_set=set('Y'))
3939

4040

41+
class TestElementwiseMulOp_scalar(ElementwiseMulOp):
42+
def setUp(self):
43+
self.op_type = "elementwise_mul"
44+
self.inputs = {
45+
'X': np.random.rand(2, 3, 4).astype(np.float32),
46+
'Y': np.random.rand(1).astype(np.float32)
47+
}
48+
self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}
49+
50+
4151
class TestElementwiseMulOp_Vector(ElementwiseMulOp):
4252
def setUp(self):
4353
self.op_type = "elementwise_mul"

python/paddle/v2/fluid/tests/test_elementwise_sub_op.py

Lines changed: 18 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserve.
22
#
3-
#Licensed under the Apache License, Version 2.0 (the "License");
4-
#you may not use this file except in compliance with the License.
5-
#You may obtain a copy of the License at
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
66
#
77
# http://www.apache.org/licenses/LICENSE-2.0
88
#
9-
#Unless required by applicable law or agreed to in writing, software
10-
#distributed under the License is distributed on an "AS IS" BASIS,
11-
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12-
#See the License for the specific language governing permissions and
13-
#limitations under the License.
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
1414
import unittest
1515
import numpy as np
1616
from op_test import OpTest
@@ -40,6 +40,16 @@ def test_check_grad_ingore_y(self):
4040
['X'], 'Out', max_relative_error=0.005, no_grad_set=set('Y'))
4141

4242

43+
class TestElementwiseSubOp_scalar(TestElementwiseOp):
44+
def setUp(self):
45+
self.op_type = "elementwise_sub"
46+
self.inputs = {
47+
'X': np.random.rand(2, 3, 4).astype(np.float32),
48+
'Y': np.random.rand(1).astype(np.float32)
49+
}
50+
self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}
51+
52+
4353
class TestElementwiseSubOp_Vector(TestElementwiseOp):
4454
def setUp(self):
4555
self.op_type = "elementwise_sub"

0 commit comments

Comments
 (0)