Skip to content

Commit 669786b

Browse files
authored
refine square_error_cost layer (#5216)
* reimplement pow operator * add pow_grad operator * fix code style * fix build error * fix op_test bug * revert pow operator * add FIXME comment
1 parent afd1e84 commit 669786b

File tree

3 files changed

+9
-9
lines changed

3 files changed

+9
-9
lines changed

paddle/operators/activation_op.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,7 @@ struct ELUGradFunctor : public BaseActivationFunctor<T> {
547547
}
548548
};
549549

550+
// FIXME(qijun) https://github.com/PaddlePaddle/Paddle/issues/5198
550551
template <typename T>
551552
struct PowFunctor : public BaseActivationFunctor<T> {
552553
float factor;

python/paddle/v2/framework/layers.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -225,10 +225,7 @@ def square_error_cost(input, label, **kwargs):
225225

226226
square_out = helper.create_tmp_variable(dtype=input.data_type)
227227
helper.append_op(
228-
type='pow',
229-
inputs={'X': [minus_out]},
230-
outputs={'Y': [square_out]},
231-
attrs={'factor': 2.0})
228+
type='square', inputs={'X': [minus_out]}, outputs={'Y': [square_out]})
232229
return square_out
233230

234231

python/paddle/v2/framework/tests/op_test.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ def find_actual(target_name, fetch_list):
281281
type(sub_out))
282282
for sub_out_name, expect in sub_out:
283283
idx = find_actual(sub_out_name, fetch_list)
284-
actual_t = np.array(outs[idx])
284+
actual = outs[idx]
285+
actual_t = np.array(actual)
285286
expect_t = expect[0] \
286287
if isinstance(expect, tuple) else expect
287288
self.assertTrue(
@@ -291,19 +292,20 @@ def find_actual(target_name, fetch_list):
291292
str(place))
292293
if isinstance(expect, tuple):
293294
self.assertListEqual(
294-
actual_t.lod(), expect[1], "Output (" + sub_out_name
295-
+ ") has different lod at " + str(place))
295+
actual.lod(), expect[1], "Output (" + sub_out_name +
296+
") has different lod at " + str(place))
296297
else:
297298
idx = find_actual(out_name, fetch_list)
298-
actual_t = outs[idx]
299+
actual = outs[idx]
300+
actual_t = np.array(actual)
299301
expect = self.outputs[out_name]
300302
expect_t = expect[0] if isinstance(expect, tuple) else expect
301303
self.assertTrue(
302304
np.allclose(
303305
actual_t, expect_t, atol=atol),
304306
"Output (" + out_name + ") has diff at " + str(place))
305307
if isinstance(expect, tuple):
306-
self.assertListEqual(actual_t.lod(), expect[1],
308+
self.assertListEqual(actual.lod(), expect[1],
307309
"Output (" + out_name +
308310
") has different lod at " + str(place))
309311

0 commit comments

Comments
 (0)