Skip to content

Commit 834face

Browse files
authored
Fix the CUDA pinned-memory bug in the compare op
Fix the CUDA pinned-memory bug in the equal op by adding a branch for inputs placed in CUDAPinnedPlace
1 parent a8e355a commit 834face

File tree

2 files changed

+42
-2
lines changed

2 files changed

+42
-2
lines changed

paddle/fluid/operators/controlflow/compare_op.cc

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -111,8 +111,16 @@ class CompareOp : public framework::OperatorWithKernel {
111111
framework::OpKernelType kt = OperatorWithKernel::GetExpectedKernelType(ctx);
112112
// CompareOp kernel's device type is decided by input tensor place
113113
bool force_cpu = ctx.Attr<bool>("force_cpu");
114-
kt.place_ = force_cpu ? platform::CPUPlace()
115-
: ctx.Input<framework::LoDTensor>("X")->place();
114+
if (force_cpu) {
115+
kt.place_ = platform::CPUPlace();
116+
} else {
117+
if (ctx.Input<framework::LoDTensor>("X")->place().type() !=
118+
typeid(platform::CUDAPinnedPlace)) {
119+
kt.place_ = ctx.Input<framework::LoDTensor>("X")->place();
120+
} else {
121+
kt.place_ = ctx.GetPlace();
122+
}
123+
}
116124
return kt;
117125
}
118126
};

python/paddle/fluid/tests/unittests/test_compare_op.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ def test_output(self):
3838
self.check_output()
3939

4040
def test_errors(self):
41+
paddle.enable_static()
4142
with program_guard(Program(), Program()):
4243
x = fluid.layers.data(name='x', shape=[2], dtype='int32')
4344
y = fluid.layers.data(name='y', shape=[2], dtype='int32')
@@ -80,6 +81,7 @@ def setUp(self):
8081
self.place = paddle.CUDAPlace(0)
8182

8283
def test_api(self):
84+
paddle.enable_static()
8385
with program_guard(Program(), Program()):
8486
x = fluid.data(name='x', shape=[4], dtype='int64')
8587
y = fluid.data(name='y', shape=[4], dtype='int64')
@@ -92,6 +94,7 @@ def test_api(self):
9294
self.assertEqual((res == self.real_result).all(), True)
9395

9496
def test_broadcast_api_1(self):
97+
paddle.enable_static()
9598
with program_guard(Program(), Program()):
9699
x = paddle.static.data(
97100
name='x', shape=[1, 2, 1, 3], dtype='int32')
@@ -108,6 +111,7 @@ def test_broadcast_api_1(self):
108111
self.assertEqual((res == real_result).all(), True)
109112

110113
def test_attr_name(self):
114+
paddle.enable_static()
111115
with program_guard(Program(), Program()):
112116
x = fluid.layers.data(name='x', shape=[4], dtype='int32')
113117
y = fluid.layers.data(name='y', shape=[4], dtype='int32')
@@ -130,6 +134,7 @@ def test_attr_name(self):
130134

131135
class TestCompareOpError(unittest.TestCase):
132136
def test_errors(self):
137+
paddle.enable_static()
133138
with program_guard(Program(), Program()):
134139
# The input x and y of compare_op must be Variable.
135140
x = fluid.layers.data(name='x', shape=[1], dtype="float32")
@@ -140,6 +145,7 @@ def test_errors(self):
140145

141146
class API_TestElementwise_Equal(unittest.TestCase):
142147
def test_api(self):
148+
paddle.enable_static()
143149
with fluid.program_guard(fluid.Program(), fluid.Program()):
144150
label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
145151
limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
@@ -159,5 +165,31 @@ def test_api(self):
159165
self.assertEqual((res == np.array([True, True])).all(), True)
160166

161167

168+
class TestCompareOpPlace(unittest.TestCase):
169+
def test_place_1(self):
170+
paddle.enable_static()
171+
place = paddle.CPUPlace()
172+
if core.is_compiled_with_cuda():
173+
place = paddle.CUDAPlace(0)
174+
label = fluid.layers.assign(np.array([3, 3], dtype="int32"))
175+
limit = fluid.layers.assign(np.array([3, 2], dtype="int32"))
176+
out = fluid.layers.less_than(label, limit, force_cpu=True)
177+
exe = fluid.Executor(place)
178+
res, = exe.run(fetch_list=[out])
179+
self.assertEqual((res == np.array([False, False])).all(), True)
180+
181+
def test_place_2(self):
182+
place = paddle.CPUPlace()
183+
data_place = place
184+
if core.is_compiled_with_cuda():
185+
place = paddle.CUDAPlace(0)
186+
data_place = paddle.CUDAPinnedPlace()
187+
paddle.disable_static(place)
188+
data = np.array([9], dtype="int64")
189+
data_tensor = paddle.to_tensor(data, place=data_place)
190+
result = data_tensor == 0
191+
self.assertEqual((result.numpy() == np.array([False])).all(), True)
192+
193+
162194
if __name__ == '__main__':
163195
unittest.main()

0 commit comments

Comments
 (0)