Skip to content

Commit b8f557f

Browse files
committed
"add elementwise_add more type"
1 parent e34e129 commit b8f557f

File tree

7 files changed

+194
-58
lines changed

7 files changed

+194
-58
lines changed

paddle/operators/accuracy_op.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,9 @@ class AccuracyKernel : public framework::OpKernel<T> {
4545
auto* correct = ctx.Output<Tensor>("Correct");
4646
auto* total = ctx.Output<Tensor>("Total");
4747

48-
float* correct_data = correct->mutable_data<float>(ctx.GetPlace());
49-
int* accuracy_data = accuracy->mutable_data<int>(ctx.GetPlace());
48+
int* correct_data = correct->mutable_data<int>(ctx.GetPlace());
5049
int* total_data = total->mutable_data<int>(ctx.GetPlace());
50+
float* accuracy_data = accuracy->mutable_data<float>(ctx.GetPlace());
5151

5252
const int64_t* indices_data = indices->data<int64_t>();
5353
const int64_t* label_data = label->data<int64_t>();

paddle/operators/elementwise_add_op.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,13 @@ REGISTER_OP(elementwise_add, ops::ElementwiseOp, ops::ElementwiseAddOpMaker,
3434
elementwise_add_grad, ops::ElementwiseOpGrad);
3535
REGISTER_OP_CPU_KERNEL(
3636
elementwise_add,
37-
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>);
37+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, float>,
38+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, double>,
39+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int>,
40+
ops::ElementwiseAddKernel<paddle::platform::CPUPlace, int64_t>);
3841
REGISTER_OP_CPU_KERNEL(
3942
elementwise_add_grad,
40-
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>);
43+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, float>,
44+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, double>,
45+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int>,
46+
ops::ElementwiseAddGradKernel<paddle::platform::CPUPlace, int64_t>);
Lines changed: 151 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,18 @@
1-
from paddle.v2.framework.framework import Program, g_main_program, unique_name
2-
from paddle.v2.framework.layer_helper import LayerHelper
1+
from paddle.v2.framework.framework import Program, g_main_program, unique_name, Variable
32
import paddle.v2.framework.core as core
43

54

5+
def _clone_var_in_block_(block, var):
6+
assert isinstance(var, Variable)
7+
return block.create_var(
8+
name=var.name,
9+
shape=var.shape,
10+
dtype=var.data_type,
11+
type=var.type,
12+
lod_level=var.lod_level,
13+
persistable=True)
14+
15+
616
class Evaluator(object):
717
"""
818
Evaluator Base class.
@@ -13,33 +23,49 @@ class Evaluator(object):
1323
"""
1424

1525
def __init__(self, name, **kwargs):
26+
"""
27+
init the global states
28+
"""
1629
self._states = {}
17-
if kwargs.has_key("program"):
18-
self._program = kwargs.get("program")
30+
if kwargs.has_key("main_program"):
31+
self._main_program = kwargs.get("main_program")
32+
else:
33+
self._main_program = g_main_program
34+
if kwargs.has_key("eval_program"):
35+
self._eval_program = kwargs.get("eval_program")
1936
else:
20-
self._program = g_main_program
37+
self._eval_program = Program()
38+
39+
def _update_ops(self):
40+
"""
41+
append update ops to the global states
42+
"""
43+
raise NotImplementedError()
2144

2245
def reset(self, executor, program=None):
2346
"""
24-
Clear metric states at the begin of each pass/user specified batch
25-
"""
47+
Clear metric states at the beginning of each pass/user-specified batch
48+
"""
2649
if program == None:
2750
reset_program = Program()
2851
else:
2952
reset_program = program
3053
block = reset_program.global_block()
3154
for k, var in self._states.iteritems():
32-
zeros = block.create_var(dtype=var.data_type)
55+
g_var = _clone_var_in_block_(block, var)
56+
zeros = block.create_var(dtype="float32", persistable=True)
3357
block.append_op(
3458
type="fill_constant",
3559
outputs={"Out": [zeros]},
3660
attrs={
37-
"shape": var.shape,
38-
"value": 0,
61+
"shape": g_var.shape,
62+
"value": .0,
63+
"data_type": 5,
3964
})
4065
block.append_op(
41-
type="scale", inputs={"X": zeros}, outputs={"Out": var})
42-
executor.run(reset_program)
66+
type="scale", inputs={"X": zeros}, outputs={"Out": g_var})
67+
print reset_program
68+
executor.run(reset_program, fetch_list=self._states.values())
4369

4470
def eval(self, executor, program=None):
4571
"""
@@ -53,22 +79,25 @@ class Accuracy(Evaluator):
5379
Accuracy need two state variable Total, Correct
5480
"""
5581

56-
def __init__(self, input, label, k=1, **kwargs):
82+
def __init__(self, *args, **kwargs):
5783
super(Accuracy, self).__init__("accuracy", **kwargs)
58-
block = self._program.global_block()
84+
# block = self._eval_program.global_block()
85+
block = self._main_program.global_block()
5986
g_total = block.create_var(
6087
name=unique_name("Total"),
6188
persistable=True,
6289
dtype="int64",
6390
shape=[1])
64-
g_correct = helper.create_global_variable(
91+
g_correct = block.create_var(
6592
name=unique_name("Correct"),
6693
persistable=True,
6794
dtype="int64",
6895
shape=[1])
6996
self._states["Total"] = g_total
7097
self._states["Correct"] = g_correct
7198

99+
def _update_ops(self, input, label, k=1, **kwargs):
100+
block = self._main_program.global_block()
72101
topk_out = block.create_var(dtype=input.data_type)
73102
topk_indices = block.create_var(dtype="int64")
74103
block.append_op(
@@ -77,8 +106,9 @@ def __init__(self, input, label, k=1, **kwargs):
77106
outputs={"Out": [topk_out],
78107
"Indices": [topk_indices]},
79108
attrs={"k": k})
80-
acc_out_dtype = kwargs.get("out_dtype", "float32")
81-
acc_out = block.create_var(dtype=acc_out_dtype)
109+
acc_out = block.create_var(dtype=kwargs.get("out_dtype", "float32"))
110+
correct = block.create_var(dtype="int64", persistable=True)
111+
total = block.create_var(dtype="int64", persistable=True)
82112
block.append_op(
83113
type="accuracy",
84114
inputs={
@@ -92,39 +122,121 @@ def __init__(self, input, label, k=1, **kwargs):
92122
"Total": [total],
93123
})
94124

125+
# block = self._eval_program.global_block()
126+
# e_correct = _clone_var_in_block_(block, correct)
127+
# e_total = _clone_var_in_block_(block, total)
128+
129+
# block.append_op(
130+
# type="sum",
131+
# inputs={"X": [self._states["Total"], total]},
132+
# outputs={"Out": [self._states["Total"]]})
133+
block.append_op(
134+
type="cast",
135+
inputs={"X": [self._states["Total"]]},
136+
outputs={"Out": [self._states["Total"]]},
137+
attrs={
138+
"in_data_type": 5,
139+
"out_data_type": 2,
140+
})
141+
block.append_op(
142+
type="cast",
143+
inputs={"X": [self._states["Correct"]]},
144+
outputs={"Out": [self._states["Correct"]]},
145+
attrs={
146+
"in_data_type": 5,
147+
"out_data_type": 2,
148+
})
149+
95150
block.append_op(
96-
type="sum",
97-
inputs={"X": [g_total, total]},
98-
outputs={"Out": [g_total]})
151+
type="elementwise_add",
152+
inputs={"X": [self._states["Total"]],
153+
"Y": [total]},
154+
outputs={"Out": [self._states["Total"]]})
99155
block.append_op(
100-
type="sum",
101-
inputs={"X": [g_correct, correct]},
102-
outputs={"Out": [g_total]})
156+
type="elementwise_add",
157+
inputs={"X": [self._states["Correct"]],
158+
"Y": [correct]},
159+
outputs={"Out": [self._states["Correct"]]})
160+
161+
# g_total = self._states["Total"]
162+
# print g_total
163+
# print total
164+
165+
# print "*" * 100
166+
# print g_total.block.program == total.block.program
167+
168+
# g_total = _clone_var_in_block_(block, self._states["Total"])
169+
# e_total = _clone_var_in_block_(block, total)
170+
171+
# block.append_op(
172+
# type="sum",
173+
# inputs={"X": [g_total, e_total]},
174+
# outputs={"Out": [g_total]})
175+
176+
# block.append_op(
177+
# type="sum",
178+
# inputs={"X": [self._states["Correct"], correct]},
179+
# outputs={"Out": [self._states["Correct"]]})
180+
# print self._main_program
103181
return acc_out
104182

105-
def eval(self, executor, program=None):
106-
if program == None:
107-
eval_program = Program()
108-
else:
109-
eval_program = program
110-
block = eval_program.global_block()
111-
eval_out = block.create_var(dtype=self._helper.input_dtype())
183+
def eval(self, executor):
184+
block = self._eval_program.global_block()
185+
eval_out = block.create_var(dtype=self._states["Total"].data_type)
186+
e_correct = _clone_var_in_block_(block, correct)
187+
e_total = _clone_var_in_block_(block, total)
188+
# block.append_op(
189+
# type="elementwise_div",
190+
# inputs={"X": self._states["Total"],
191+
# "Y": self._states["Correct"]},
192+
# outputs={"Out": eval_out})
112193
block.append_op(
113194
type="elementwise_div",
114-
inputs={"X": self._states["Total"],
115-
"Y": self._states["Correct"]},
195+
inputs={"X": e_total,
196+
"Y": e_correct},
116197
outputs={"Out": eval_out})
117-
return executor.run(eval_program, fetch_list=[eval_out])
198+
return executor.run(self._eval_program, fetch_list=[eval_out])
118199

119200

120-
# Demo for composing low level op to compute the F1 metric
121-
class F1(Evaluator):
122-
def __init__(self, input, label, **kwargs):
123-
super(F1, self).__init__("F1", **kwargs)
124-
g_tp = helper.create_global_variable(
201+
# Demo for composing low level ops to compute the F1 metric
202+
class FScore(Evaluator):
203+
def __init__(self, input, label, beta=1.0, **kwargs):
204+
super(F1, self).__init__("FScore", **kwargs)
205+
block = self._program.global_block()
206+
g_tp = block.create_var(
125207
name=unique_name("Tp"), persistable=True, dtype="int64", shape=[1])
126-
g_fp = helper.create_global_variable(
208+
g_fn = block.create_var(
209+
name=unique_name("Fn"), persistable=True, dtype="int64", shape=[1])
210+
g_fp = block.create_var(
127211
name=unique_name("Fp"), persistable=True, dtype="int64", shape=[1])
128212

129213
self._states["Tp"] = g_tp
130214
self._states["Fp"] = g_fp
215+
self._states["Fn"] = g_fn
216+
217+
def _update_ops(self):
218+
block = self._program.global_block()
219+
equal_out = block.create_var()
220+
block.append_op(
221+
type="equal",
222+
inputs={"X": [input],
223+
"Y": [label]},
224+
outputs={"Out": equal_out})
225+
226+
positive = block.create_var()
227+
block.append_op(
228+
type="sequence_pool",
229+
inputs={"X": [equal_out]},
230+
outputs={"Out": positive},
231+
attrs={"pooltype": "SUM"})
232+
batch = block.create_var(
233+
name=feed_var_name,
234+
type=core.VarDesc.VarType.FEED_MINIBATCH,
235+
persistable=True)
236+
237+
238+
# def register():
239+
accuracy = Accuracy
240+
# def accuracy(*args, **kwargs):
241+
# acc = Accuracy(**kwargs)
242+
# return acc._update_ops(*args, **kwargs)

python/paddle/v2/framework/framework.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -550,7 +550,7 @@ def __init__(self, block, shape, dtype, **kwargs):
550550
raise ValueError("Parameter shape should not be related with "
551551
"batch-size")
552552

553-
super(Parameter, self).__init__(
553+
Variable.__init__(
554554
self, block, persistable=True, shape=shape, dtype=dtype, **kwargs)
555555
self.trainable = kwargs.get('trainable', True)
556556

python/paddle/v2/framework/layers.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -263,15 +263,21 @@ def accuracy(input, label, k=1, **kwargs):
263263
"Indices": [topk_indices]},
264264
attrs={"k": k})
265265
acc_out_dtype = kwargs.get("out_dtype", "float32")
266-
acc_out = helper.create_tmp_variable(dtype=acc_out_dtype)
266+
acc_out = helper.create_tmp_variable(dtype="float32")
267+
correct = helper.create_tmp_variable(dtype="int64")
268+
total = helper.create_tmp_variable(dtype="int64")
267269
helper.append_op(
268270
type="accuracy",
269271
inputs={
270272
"Out": [topk_out],
271273
"Indices": [topk_indices],
272274
"Label": [label]
273275
},
274-
outputs={"Accuracy": [acc_out]})
276+
outputs={
277+
"Accuracy": [acc_out],
278+
"Correct": [correct],
279+
"Total": [total],
280+
})
275281
return acc_out
276282

277283

python/paddle/v2/framework/tests/test_accuracy_op.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@ def setUp(self):
1919
break
2020
self.outputs = {
2121
'Accuracy': np.array([num_correct / float(n)]).astype("float32"),
22-
'Correct': np.array([num_correct]).astype("int32")
22+
'Correct': np.array([num_correct]).astype("int32"),
23+
'Total': np.array([n]).astype("int32")
2324
}
2425

2526
def test_check_output(self):
2627
self.check_output()
2728

2829

2930
if __name__ == '__main__':
30-
exit(0)
3131
unittest.main()

0 commit comments

Comments
 (0)