Skip to content

Commit c79d530

Browse files
author
Yancey
authored
Add split selected rows op (#7604)
* add split selected rows op * update comment * add grad check * registry cuda kernel * fix ci failed
1 parent 161bd4a commit c79d530

File tree

4 files changed

+319
-0
lines changed

4 files changed

+319
-0
lines changed
Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/split_selected_rows_op.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
// Describes the split_selected_rows operator: one input "X", a duplicable
// output "Out", two integer-vector attributes ("rows_sections" and
// "height_sections", both defaulting to an empty vector) and the user-facing
// documentation string.
class SplitSelectedRowsOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  SplitSelectedRowsOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "The input SelectedRows.");
    AddOutput("Out", "The outputs of input SelectedRows.").AsDuplicable();
    AddAttr<std::vector<int>>("rows_sections", "Rows section for output.")
        .SetDefault(std::vector<int>({}));
    AddAttr<std::vector<int>>("height_sections",
                              "Height for each output SelectedRows.")
        .SetDefault(std::vector<int>({}));

    // The example below has exactly two outputs (out0 and out1), so the last
    // line documents out1.height.
    AddComment(R"DOC(
Split a SelectedRows with a specified rows section.
height_sections is only needed when need to split the dims of the original tensor.

Example:
  Input:
    X.rows = {0, 7, 5}
    X.height = 12
  Attr:
    rows_sections = {1, 2}
    height_sections = {}
  Out:
    out0.rows = {0}
    out0.height = 12
    out1.rows = {7, 5}
    out1.height = 12

)DOC");
  }
};
52+
53+
class SplitSelectedRowsOp : public framework::OperatorWithKernel {
54+
public:
55+
using framework::OperatorWithKernel::OperatorWithKernel;
56+
57+
void InferShape(framework::InferShapeContext *ctx) const override {
58+
PADDLE_ENFORCE(ctx->HasInput("X"), "SplitSelectedRowsOp must has input X.");
59+
PADDLE_ENFORCE(ctx->HasOutputs("Out"),
60+
"SplitSelectedRowsOp must has output Out.");
61+
62+
std::vector<int> height_sections =
63+
ctx->Attrs().Get<std::vector<int>>("height_sections");
64+
std::vector<int> rows_sections =
65+
ctx->Attrs().Get<std::vector<int>>("rows_sections");
66+
PADDLE_ENFORCE_EQ(
67+
rows_sections.size(), ctx->Outputs("Out").size(),
68+
"The size of rows section should be the same with Outputs size.");
69+
int64_t n = ctx->Outputs("Out").size();
70+
71+
std::vector<framework::DDim> outs_dims;
72+
outs_dims.reserve(n);
73+
74+
// make output dims
75+
for (int64_t i = 0; i < n; ++i) {
76+
auto dims = ctx->GetInputDim("X");
77+
if (height_sections.size()) {
78+
PADDLE_ENFORCE_EQ(
79+
height_sections.size(), static_cast<size_t>(n),
80+
"The size of height section should be the same with height"
81+
" section size.");
82+
dims[0] = height_sections[i];
83+
}
84+
outs_dims.push_back(dims);
85+
}
86+
ctx->SetOutputsDim("Out", outs_dims);
87+
}
88+
};
89+
90+
class SplitSelectedRowsGradMaker : public framework::SingleGradOpDescMaker {
91+
public:
92+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
93+
94+
protected:
95+
std::unique_ptr<framework::OpDesc> Apply() const override {
96+
auto *grad_op = new framework::OpDesc();
97+
grad_op->SetType("sum");
98+
grad_op->SetInput("X", OutputGrad("Out"));
99+
grad_op->SetOutput("Out", InputGrad("X"));
100+
grad_op->SetAttrMap(Attrs());
101+
return std::unique_ptr<framework::OpDesc>(grad_op);
102+
}
103+
};
104+
105+
} // namespace operators
106+
} // namespace paddle
107+
108+
namespace ops = paddle::operators;
109+
REGISTER_OPERATOR(split_selected_rows, ops::SplitSelectedRowsOp,
110+
ops::SplitSelectedRowsOpMaker,
111+
ops::SplitSelectedRowsGradMaker);
112+
REGISTER_OP_CPU_KERNEL(
113+
split_selected_rows,
114+
ops::SplitSelectedRowsOpKernel<paddle::platform::CPUPlace, float>);
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/operators/split_selected_rows_op.h"
16+
namespace ops = paddle::operators;
17+
REGISTER_OP_CUDA_KERNEL(
18+
split_selected_rows,
19+
ops::SplitSelectedRowsOpKernel<paddle::platform::CUDADeviceContext, float>);
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <vector>
18+
#include "paddle/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename DeviceContext, typename T>
24+
class SplitSelectedRowsOpKernel : public framework::OpKernel<T> {
25+
public:
26+
void Compute(const framework::ExecutionContext& ctx) const override {
27+
auto* x = ctx.Input<framework::SelectedRows>("X");
28+
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
29+
30+
auto rows_sections = ctx.Attr<std::vector<int>>("rows_sections");
31+
auto height_sections = ctx.Attr<std::vector<int>>("height_sections");
32+
33+
int64_t n = outs.size();
34+
int offset = 0;
35+
36+
for (int64_t i = 0; i < n; ++i) {
37+
framework::Vector<int64_t> out_rows;
38+
for (int64_t j = 0; j < rows_sections[i]; ++j) {
39+
out_rows.push_back(x->rows()[offset + j]);
40+
}
41+
42+
auto& out = outs[i];
43+
auto x_dims = x->GetCompleteDims();
44+
x_dims[0] = rows_sections[i];
45+
out->mutable_value()->mutable_data<T>(x_dims, ctx.GetPlace());
46+
framework::Copy(x->value().Slice(offset, rows_sections[i] + offset),
47+
x->place(), ctx.device_context(), out->mutable_value());
48+
outs[i]->set_rows(out_rows);
49+
if (height_sections.size()) {
50+
outs[i]->set_height(height_sections[i]);
51+
}
52+
offset += rows_sections[i];
53+
}
54+
}
55+
};
56+
57+
} // namespace operators
58+
} // namespace paddle
Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
#Licensed under the Apache License, Version 2.0 (the "License");
4+
#you may not use this file except in compliance with the License.
5+
#You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
#Unless required by applicable law or agreed to in writing, software
10+
#distributed under the License is distributed on an "AS IS" BASIS,
11+
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
#See the License for the specific language governing permissions and
13+
#limitations under the License.
14+
import unittest
15+
import paddle.v2.fluid.core as core
16+
import numpy as np
17+
from paddle.v2.fluid.op import Operator
18+
19+
20+
class TestSpliteSelectedRows(unittest.TestCase):
    """Checks the split_selected_rows operator and the "sum" op used as its
    gradient, on every place returned by get_places()."""

    def get_places(self):
        """Return a CPU place, plus CUDA device 0 when core.is_compile_gpu()
        is true."""
        places = [core.CPUPlace()]
        if core.is_compile_gpu():
            places.append(core.CUDAPlace(0))
        return places

    def test_check_output(self):
        """Run the forward check on each place."""
        for place in self.get_places():
            self.check_with_place(place)

    def test_check_grad(self):
        """Run the gradient check on each place."""
        for place in self.get_places():
            self.check_grad_with_place(place)

    def check_with_place(self, place):
        """Split a 4-row SelectedRows into two 2-row outputs and verify the
        rows, heights and two marked values of the outputs."""
        scope = core.Scope()
        src_rows = [0, 5, 7, 4]
        src_height = 10
        width = 2

        # Input variable X: all ones except two marked entries.
        x = scope.var('X').get_selected_rows()
        x.set_rows(src_rows)
        x.set_height(src_height)
        values = np.ones((len(src_rows), width)).astype("float32")
        values[0, 0] = 2.0
        values[2, 1] = 4.0
        x.get_tensor().set(values, place)

        # Output variables out0 and out1.
        out0 = scope.var('out0').get_selected_rows()
        out1 = scope.var('out1').get_selected_rows()

        split_op = Operator(
            "split_selected_rows",
            X="X",
            Out=["out0", "out1"],
            rows_sections=[2, 2],
            height_sections=[])
        split_op.run(scope, place)

        # Each output gets two consecutive rows of the input.
        self.assertEqual(out0.rows(), [0, 5])
        self.assertEqual(out1.rows(), [7, 4])

        # With empty height_sections both outputs keep the input height.
        self.assertEqual(out0.height(), src_height)
        self.assertEqual(out1.height(), src_height)

        # Input row 2 is the first row of out1, hence index [0, 1] there.
        self.assertAlmostEqual(2.0, np.array(out0.get_tensor())[0, 0])
        self.assertAlmostEqual(4.0, np.array(out1.get_tensor())[0, 1])

    def check_grad_with_place(self, place):
        """Feed two output gradients into a "sum" op and verify the rows,
        height and two marked values of the resulting X@GRAD."""
        scope = core.Scope()
        height = 10
        row_numel = 2

        # (variable name, rows, index of the marked entry, marked value)
        grad_specs = [
            ("out0@GRAD", [0, 5], (0, 0), 2.0),
            ("out1@GRAD", [7, 5], (0, 1), 4.0),
        ]
        for var_name, grad_rows, marked_index, marked_value in grad_specs:
            grad = scope.var(var_name).get_selected_rows()
            grad.set_rows(grad_rows)
            grad.set_height(height)
            grad_tensor = grad.get_tensor()
            data = np.ones((len(grad_rows), row_numel)).astype("float32")
            data[marked_index] = marked_value
            grad_tensor.set(data, place)

        x_grad = scope.var("X@GRAD").get_selected_rows()

        grad_op = Operator(
            "sum",
            X=["out0@GRAD", "out1@GRAD"],
            Out="X@GRAD",
            rows_sections=[2, 2],
            height_sections=[])
        grad_op.run(scope, place)

        # The expected rows are those of out0@GRAD followed by out1@GRAD.
        self.assertEqual(x_grad.rows(), grad_specs[0][1] + grad_specs[1][1])
        self.assertEqual(x_grad.height(), height)

        self.assertAlmostEqual(2.0, np.array(x_grad.get_tensor())[0, 0])
        self.assertAlmostEqual(4.0, np.array(x_grad.get_tensor())[2, 1])
125+
126+
127+
# Run the test cases when this file is executed as a script.
if __name__ == "__main__":
    unittest.main()

0 commit comments

Comments
 (0)