Skip to content

Commit f45b0b0

Browse files
authored
Merge pull request #7688 from reyoung/feature/python_overload_math_operators
Add math operator patches
2 parents cb17dd2 + 87b424e commit f45b0b0

File tree

4 files changed

+338
-0
lines changed

4 files changed

+338
-0
lines changed

python/paddle/v2/fluid/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@
3737
from memory_optimization_transpiler import memory_optimize
3838

3939
Tensor = LoDTensor
40+
4041
__all__ = framework.__all__ + executor.__all__ + [
4142
'io',
4243
'initializer',
@@ -94,4 +95,5 @@ def __bootstrap__():
9495
core.init_devices()
9596

9697

98+
layers.monkey_patch_variable()
9799
__bootstrap__()

python/paddle/v2/fluid/layers/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@
2424
from control_flow import *
2525
import device
2626
from device import *
27+
import math_op_patch
28+
from math_op_patch import *
2729

2830
__all__ = []
2931
__all__ += nn.__all__
@@ -32,3 +34,4 @@
3234
__all__ += control_flow.__all__
3335
__all__ += ops.__all__
3436
__all__ += device.__all__
37+
__all__ += math_op_patch.__all__
Lines changed: 152 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from ..framework import Variable, unique_name
16+
from ..registry import OpProtoHolder
17+
18+
__all__ = ['monkey_patch_variable']
19+
20+
21+
def monkey_patch_variable():
    """Inject Python arithmetic operators and ``astype`` onto ``Variable``.

    Called once at package bootstrap. After patching, expressions such as
    ``a + 1`` or ``2 / b`` on a ``Variable`` append the corresponding
    elementwise operator (plus any needed fill/cast ops) to the variable's
    block and return the result variable.
    """

    def unique_tmp_name():
        # Fresh, collision-free name for each intermediate result variable.
        return unique_name("tmp")

    def safe_get_dtype(var):
        """Return ``var.dtype`` or raise a descriptive ValueError."""
        try:
            dtype = var.dtype
        except Exception:
            # BUG FIX: original used a bare ``except`` and passed printf-style
            # args to ValueError as a tuple, so the message never contained
            # the variable name. Format eagerly and keep the exception scoped.
            raise ValueError("Cannot get data type from %s" % var.name)
        return dtype

    def create_tensor(block, value, dtype, shape):
        """Append a fill_constant op producing a ``shape`` tensor of ``value``."""
        value = float(value)
        tmp_name = unique_tmp_name()
        var = block.create_var(name=tmp_name, shape=shape, dtype=dtype)
        block.append_op(
            type="fill_constant",
            outputs={'Out': [var]},
            attrs={'dtype': var.dtype,
                   'shape': shape,
                   'value': value})
        return var

    def create_scalar(block, value, dtype):
        # A scalar is represented as a [1]-shaped tensor; elementwise ops
        # broadcast it against the other operand.
        return create_tensor(block, value, dtype, shape=[1])

    def create_tensor_with_batchsize(ref_var, value, dtype):
        """Create a constant tensor whose batch dimension follows ``ref_var``.

        Used when ``ref_var`` has a dynamic (negative) dimension, so the
        constant's full shape is only known at runtime.
        """
        assert isinstance(ref_var, Variable)
        value = float(value)
        tmp_name = unique_tmp_name()
        var = ref_var.block.create_var(name=tmp_name, dtype=dtype)
        ref_var.block.append_op(
            type='fill_constant_batch_size_like',
            outputs={'Out': [var]},
            inputs={'Input': [ref_var]},
            attrs={'shape': ref_var.shape,
                   'value': value})
        return var

    def astype(self, dtype):
        """
        Cast a variable to a specified data type.
        NOTE: The variable must be a Tensor
        Args:
            self(Variable): The source variable
            dtype: The target dtype

        Returns:
            Variable with new dtype
        """
        tmp_name = unique_tmp_name()
        out = self.block.create_var(name=tmp_name, dtype=dtype)
        self.block.append_op(
            type="cast",
            inputs={"X": [self]},
            outputs={"Out": [out]},
            attrs={"in_dtype": self.dtype,
                   "out_dtype": out.dtype})
        return out

    def _elemwise_method_creator_(method_name, op_type, reverse=False):
        """Build a dunder method that appends ``op_type`` to the block.

        Args:
            method_name(str): name the method will be installed under.
            op_type(str): elementwise operator type to append.
            reverse(bool): True for reflected operators (e.g. ``__rsub__``)
                where the scalar is the left-hand operand.
        """

        def __impl__(self, other_var):
            lhs_dtype = safe_get_dtype(self)

            if not isinstance(other_var, Variable):
                # ``other_var`` is a Python number; promote it to a Variable.
                if reverse:
                    # Reflected ops put the constant on the X side, which must
                    # match self's full shape (no scalar broadcast there).
                    has_batch_size = any(elem < 0 for elem in self.shape)
                    if not has_batch_size:
                        other_var = create_tensor(
                            self.block,
                            other_var,
                            dtype=lhs_dtype,
                            shape=self.shape)
                    else:
                        other_var = create_tensor_with_batchsize(
                            self, other_var, lhs_dtype)
                else:
                    # add fill_op to self.block
                    other_var = create_scalar(
                        self.block, value=other_var, dtype=lhs_dtype)

            rhs_dtype = safe_get_dtype(other_var)
            if lhs_dtype != rhs_dtype:
                # Unify operand dtypes; the left-hand dtype wins.
                other_var = astype(other_var, lhs_dtype)
            if reverse:
                # Swap operands so e.g. ``1 - x`` emits sub(const, x).
                self, other_var = other_var, self

            tmp_name = unique_tmp_name()
            out = self.block.create_var(name=tmp_name, dtype=lhs_dtype)
            self.block.append_op(
                type=op_type,
                inputs={'X': [self],
                        'Y': [other_var]},
                outputs={'Out': out})
            return out

        comment = OpProtoHolder.instance().get_op_proto(op_type).comment

        __impl__.__doc__ = """
        {0}
        Args:
            self(Variable): left hand variable
            other_var(Variable|float|int): right hand variable

        Returns:
            Variable
        """.format(comment)
        __impl__.__name__ = method_name
        return __impl__

    # inject methods
    for method_name, op_type, reverse in (
        ("__add__", "elementwise_add", False),
            # a+b == b+a. Do not need to reverse explicitly
        ("__radd__", "elementwise_add", False),
        ("__sub__", "elementwise_sub", False),
        ("__rsub__", "elementwise_sub", True),
        ("__mul__", "elementwise_mul", False),
            # a*b == b*a. Do not need to reverse explicitly
        ("__rmul__", "elementwise_mul", False),
        ("__div__", "elementwise_div", False),
            ("__rdiv__", "elementwise_div", True)):
        setattr(Variable, method_name,
                _elemwise_method_creator_(method_name, op_type, reverse))

    Variable.astype = astype
Lines changed: 181 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,181 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
import decorators
17+
import paddle.v2.fluid as fluid
18+
import numpy
19+
20+
21+
class TestMathOpPatches(unittest.TestCase):
    """End-to-end tests for the arithmetic operators monkey-patched onto
    ``Variable``: scalar and tensor operands for +, -, *, / including the
    reflected (r*) forms.
    """

    @decorators.prog_scope()
    def test_add_scalar(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = a + 10
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(a_np + 10, b_np))

    @decorators.prog_scope()
    def test_radd_scalar(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = 10 + a
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(a_np + 10, b_np))

    @decorators.prog_scope()
    def test_sub_scalar(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = a - 10
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(a_np - 10, b_np))

    @decorators.prog_scope()
    def test_rsub_scalar(self):
        # BUG FIX: this method was also named ``test_radd_scalar``, silently
        # shadowing the real radd test above so it never ran. It exercises
        # ``10 - a`` (__rsub__), hence the corrected name.
        a = fluid.layers.data(name="a", shape=[1])
        b = 10 - a
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(10 - a_np, b_np))

    @decorators.prog_scope()
    def test_mul_scalar(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = a * 10
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(a_np * 10, b_np))

    @decorators.prog_scope()
    def test_rmul_scalar(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = 10 * a
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(10 * a_np, b_np))

    @decorators.prog_scope()
    def test_div_scalar(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = a / 10
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(a_np / 10, b_np))

    @decorators.prog_scope()
    def test_rdiv_scalar(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = 10 / a
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        # Offset away from zero so the division stays well-conditioned.
        a_np = numpy.random.random(size=[10, 1]).astype('float32') + 1e-2

        b_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np},
                       fetch_list=[b])
        self.assertTrue(numpy.allclose(10 / a_np, b_np))

    @decorators.prog_scope()
    def test_div_two_tensor(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = fluid.layers.data(name="b", shape=[1])
        c = a / b
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        # Offset the divisor away from zero.
        b_np = numpy.random.random(size=[10, 1]).astype('float32') + 1e-2
        c_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np,
                             'b': b_np},
                       fetch_list=[c])
        self.assertTrue(numpy.allclose(a_np / b_np, c_np))

    @decorators.prog_scope()
    def test_mul_two_tensor(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = fluid.layers.data(name="b", shape=[1])
        c = a * b
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = numpy.random.random(size=[10, 1]).astype('float32')
        c_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np,
                             'b': b_np},
                       fetch_list=[c])
        self.assertTrue(numpy.allclose(a_np * b_np, c_np))

    @decorators.prog_scope()
    def test_add_two_tensor(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = fluid.layers.data(name="b", shape=[1])
        c = a + b
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = numpy.random.random(size=[10, 1]).astype('float32')
        c_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np,
                             'b': b_np},
                       fetch_list=[c])
        self.assertTrue(numpy.allclose(a_np + b_np, c_np))

    @decorators.prog_scope()
    def test_sub_two_tensor(self):
        a = fluid.layers.data(name="a", shape=[1])
        b = fluid.layers.data(name="b", shape=[1])
        c = a - b
        place = fluid.CPUPlace()
        exe = fluid.Executor(place)
        a_np = numpy.random.random(size=[10, 1]).astype('float32')
        b_np = numpy.random.random(size=[10, 1]).astype('float32')
        c_np = exe.run(fluid.default_main_program(),
                       feed={"a": a_np,
                             'b': b_np},
                       fetch_list=[c])
        self.assertTrue(numpy.allclose(a_np - b_np, c_np))
178+
179+
180+
# Run the operator-overload tests when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)