Skip to content

Commit 46c61b3

Browse files
committed
"add elementwise op support"
1 parent b8f557f commit 46c61b3

File tree

5 files changed: +71 additions, −78 deletions

paddle/operators/elementwise_div_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,13 @@ REGISTER_OP(elementwise_div, ops::ElementwiseOp, ops::ElementwiseDivOpMaker,
3535
elementwise_div_grad, ops::ElementwiseOpGrad);
3636
REGISTER_OP_CPU_KERNEL(
3737
elementwise_div,
38-
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>);
38+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, float>,
39+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, double>,
40+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int>,
41+
ops::ElementwiseDivKernel<paddle::platform::CPUPlace, int64_t>);
3942
REGISTER_OP_CPU_KERNEL(
4043
elementwise_div_grad,
41-
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>);
44+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, float>,
45+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, double>,
46+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int>,
47+
ops::ElementwiseDivGradKernel<paddle::platform::CPUPlace, int64_t>);

paddle/operators/elementwise_mul_op.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,12 @@ REGISTER_OP(elementwise_mul, ops::ElementwiseOp, ops::ElementwiseMulOpMaker,
3737
REGISTER_OP_CPU_KERNEL(
3838
elementwise_mul,
3939
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, float>,
40-
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>);
40+
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, double>,
41+
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int>,
42+
ops::ElementwiseMulKernel<paddle::platform::CPUPlace, int64_t>);
4143
REGISTER_OP_CPU_KERNEL(
4244
elementwise_mul_grad,
4345
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, float>,
44-
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>);
46+
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, double>,
47+
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int>,
48+
ops::ElementwiseMulGradKernel<paddle::platform::CPUPlace, int64_t>);

paddle/operators/elementwise_sub_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ REGISTER_OP(elementwise_sub, ops::ElementwiseOp, ops::ElementwiseSubOpMaker,
3434
elementwise_sub_grad, ops::ElementwiseOpGrad);
3535
REGISTER_OP_CPU_KERNEL(
3636
elementwise_sub,
37-
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>);
37+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, float>,
38+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, double>,
39+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int>,
40+
ops::ElementwiseSubKernel<paddle::platform::CPUPlace, int64_t>);
3841
REGISTER_OP_CPU_KERNEL(
3942
elementwise_sub_grad,
40-
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>);
43+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, float>,
44+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, double>,
45+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int>,
46+
ops::ElementwiseSubGradKernel<paddle::platform::CPUPlace, int64_t>);

python/paddle/v2/framework/evaluator.py

Lines changed: 39 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from paddle.v2.framework.framework import Program, g_main_program, unique_name, Variable
23
import paddle.v2.framework.core as core
34

@@ -31,12 +32,8 @@ def __init__(self, name, **kwargs):
3132
self._main_program = kwargs.get("main_program")
3233
else:
3334
self._main_program = g_main_program
34-
if kwargs.has_key("eval_program"):
35-
self._eval_program = kwargs.get("eval_program")
36-
else:
37-
self._eval_program = Program()
3835

39-
def _update_ops(self):
36+
def _update_ops(self, *args, **kwargs):
4037
"""
4138
append update ops to the global states
4239
"""
@@ -64,13 +61,12 @@ def reset(self, executor, program=None):
6461
})
6562
block.append_op(
6663
type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
67-
print reset_program
6864
executor.run(reset_program, fetch_list=self._states.values())
6965

7066
def eval(self, executor, program=None):
7167
"""
72-
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
73-
"""
68+
Merge the mini-batch statistics to form the evaluation result for multiple mini-batches.
69+
"""
7470
raise NotImplementedError()
7571

7672

@@ -81,7 +77,6 @@ class Accuracy(Evaluator):
8177

8278
def __init__(self, *args, **kwargs):
8379
super(Accuracy, self).__init__("accuracy", **kwargs)
84-
# block = self._eval_program.global_block()
8580
block = self._main_program.global_block()
8681
g_total = block.create_var(
8782
name=unique_name("Total"),
@@ -122,21 +117,13 @@ def _update_ops(self, input, label, k=1, **kwargs):
122117
"Total": [total],
123118
})
124119

125-
# block = self._eval_program.global_block()
126-
# e_correct = _clone_var_in_block_(block, correct)
127-
# e_total = _clone_var_in_block_(block, total)
128-
129-
# block.append_op(
130-
# type="sum",
131-
# inputs={"X": [self._states["Total"], total]},
132-
# outputs={"Out": [self._states["Total"]]})
133120
block.append_op(
134121
type="cast",
135122
inputs={"X": [self._states["Total"]]},
136123
outputs={"Out": [self._states["Total"]]},
137124
attrs={
138-
"in_data_type": 5,
139-
"out_data_type": 2,
125+
"in_data_type": 5, # float32
126+
"out_data_type": 2, #int32
140127
})
141128
block.append_op(
142129
type="cast",
@@ -158,44 +145,40 @@ def _update_ops(self, input, label, k=1, **kwargs):
158145
"Y": [correct]},
159146
outputs={"Out": [self._states["Correct"]]})
160147

161-
# g_total = self._states["Total"]
162-
# print g_total
163-
# print total
164-
165-
# print "*" * 100
166-
# print g_total.block.program == total.block.program
167-
168-
# g_total = _clone_var_in_block_(block, self._states["Total"])
169-
# e_total = _clone_var_in_block_(block, total)
170-
171-
# block.append_op(
172-
# type="sum",
173-
# inputs={"X": [g_total, e_total]},
174-
# outputs={"Out": [g_total]})
175-
176-
# block.append_op(
177-
# type="sum",
178-
# inputs={"X": [self._states["Correct"], correct]},
179-
# outputs={"Out": [self._states["Correct"]]})
180-
# print self._main_program
181148
return acc_out
182149

183-
def eval(self, executor):
184-
block = self._eval_program.global_block()
150+
def eval(self, executor, program=None):
151+
if program != None:
152+
eval_program = program
153+
else:
154+
eval_program = Program()
155+
block = eval_program.global_block()
185156
eval_out = block.create_var(dtype=self._states["Total"].data_type)
186-
e_correct = _clone_var_in_block_(block, correct)
187-
e_total = _clone_var_in_block_(block, total)
188-
# block.append_op(
189-
# type="elementwise_div",
190-
# inputs={"X": self._states["Total"],
191-
# "Y": self._states["Correct"]},
192-
# outputs={"Out": eval_out})
157+
e_total = _clone_var_in_block_(block, self._states["Total"])
158+
e_correct = _clone_var_in_block_(block, self._states["Correct"])
159+
block.append_op(
160+
type="cast",
161+
inputs={"X": [e_total]},
162+
outputs={"Out": [e_total]},
163+
attrs={
164+
"in_data_type": 2, #int32
165+
"out_data_type": 5, #float32
166+
})
167+
block.append_op(
168+
type="cast",
169+
inputs={"X": [e_correct]},
170+
outputs={"Out": [e_correct]},
171+
attrs={
172+
"in_data_type": 2,
173+
"out_data_type": 5,
174+
})
193175
block.append_op(
194176
type="elementwise_div",
195-
inputs={"X": e_total,
196-
"Y": e_correct},
177+
inputs={"X": e_correct,
178+
"Y": e_total},
197179
outputs={"Out": eval_out})
198-
return executor.run(self._eval_program, fetch_list=[eval_out])
180+
out = executor.run(eval_program, fetch_list=[eval_out])
181+
return np.array(out[0])
199182

200183

201184
# Demo for composing low level ops to compute the F1 metric
@@ -235,8 +218,8 @@ def _update_ops(self):
235218
persistable=True)
236219

237220

238-
# def register():
239-
accuracy = Accuracy
240-
# def accuracy(*args, **kwargs):
241-
# acc = Accuracy(**kwargs)
242-
# return acc._update_ops(*args, **kwargs)
221+
# FIXME(dzh): add a decorator to call _update_ops automatically
222+
def accuracy(*args, **kwargs):
223+
cls = Accuracy(*args, **kwargs)
224+
out = cls._update_ops(*args, **kwargs)
225+
return cls, out

python/paddle/v2/framework/tests/test_recognize_digits_conv.py

Lines changed: 10 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,23 +55,14 @@
5555
main_program=main_program,
5656
startup_program=startup_program)
5757
avg_cost = layers.mean(x=cost, main_program=main_program)
58-
# accuracy = layers.accuracy(
59-
# input=predict,
60-
# label=label,
61-
# main_program=main_program,
62-
# startup_program=startup_program)
63-
# optimizer = optimizer.MomentumOptimizer(learning_rate=0.1 / 128.0,
64-
# momentum=0.9)
6558
optimizer = optimizer.AdamOptimizer(learning_rate=0.01, beta1=0.9, beta2=0.999)
6659
opts = optimizer.minimize(avg_cost, startup_program)
6760

68-
accuracy = evaluator.accuracy(
61+
accuracy, acc_out = evaluator.accuracy(
6962
input=predict,
7063
label=label,
7164
main_program=main_program,
7265
startup_program=startup_program)
73-
acc_out = accuracy._update_ops(
74-
input=predict, label=label, main_program=main_program)
7566

7667
BATCH_SIZE = 50
7768
PASS_NUM = 3
@@ -105,11 +96,14 @@
10596
fetch_list=[avg_cost, acc_out])
10697
loss = np.array(outs[0])
10798
acc = np.array(outs[1])
108-
# pass_acc = accuracy.eval(exe)
109-
# print pass_acc
110-
print loss, acc
99+
pass_acc = accuracy.eval(exe)
100+
print "pass id : ", pass_id, pass_acc
101+
# print loss, acc
102+
if loss < 10.0 and acc > 0.9:
103+
# if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
104+
exit(0)
105+
106+
pass_acc = accuracy.eval(exe)
107+
print "pass id : ", pass_id, pass_acc
111108

112-
# if loss < 10.0 and acc > 0.9:
113-
# # if avg cost less than 10.0 and accuracy is larger than 0.9, we think our code is good.
114-
# exit(0)
115109
exit(1)

Comments (0)