Skip to content

Commit d655417

Browse files
authored
Merge pull request #9956 from typhoonzero/split_byref_op
Split byref op
2 parents f7fbef1 + ff0d934 commit d655417

File tree

8 files changed

+193
-21
lines changed

8 files changed

+193
-21
lines changed

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
8282
platform::CPUPlace cpu;
8383
auto& gpu_dev_ctx =
8484
static_cast<const platform::CUDADeviceContext&>(ctx);
85-
auto copy_size = tensor.memory_size();
85+
auto copy_size = tensor.numel() * framework::SizeOfType(tensor.type());
8686
payload = memory::Alloc(cpu, copy_size);
8787

8888
memory::Copy(cpu, payload,
@@ -99,7 +99,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
9999
} else {
100100
payload = tensor.data<void>();
101101
}
102-
payload_size = tensor.memory_size();
102+
payload_size = tensor.numel() * framework::SizeOfType(tensor.type());
103103
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
104104
} break;
105105
case framework::proto::VarType_Type_SELECTED_ROWS: {
@@ -118,7 +118,8 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
118118
platform::CPUPlace cpu;
119119
auto& gpu_dev_ctx =
120120
static_cast<const platform::CUDADeviceContext&>(ctx);
121-
auto copy_size = tensor->memory_size();
121+
auto copy_size =
122+
tensor->numel() * framework::SizeOfType(tensor->type());
122123
payload = memory::Alloc(cpu, copy_size);
123124
memory::Copy(cpu, payload,
124125
boost::get<platform::CUDAPlace>(tensor->place()),
@@ -133,7 +134,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
133134
} else {
134135
payload = slr->mutable_value()->data<void>();
135136
}
136-
payload_size = tensor->memory_size();
137+
payload_size = tensor->numel() * framework::SizeOfType(tensor->type());
137138
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
138139
} break;
139140
default:
Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/split_byref_op.h"
16+
#include "paddle/fluid/operators/split_op.h"
17+
18+
namespace paddle {
19+
namespace operators {
20+
using framework::Tensor;
21+
22+
class SplitByrefOp : public framework::OperatorWithKernel {
23+
public:
24+
using framework::OperatorWithKernel::OperatorWithKernel;
25+
26+
void InferShape(framework::InferShapeContext *ctx) const override {
27+
PADDLE_ENFORCE(ctx->HasInput("X"),
28+
"Input(X) of SplitOp should not be null.");
29+
PADDLE_ENFORCE_GE(ctx->Outputs("Out").size(), 1UL,
30+
"Outputs(Out) of SplitOp should not be empty.");
31+
auto in_dims = ctx->GetInputDim("X");
32+
auto outs_names = ctx->Outputs("Out");
33+
size_t num = static_cast<size_t>(ctx->Attrs().Get<int>("num"));
34+
std::vector<int> sections = static_cast<std::vector<int>>(
35+
ctx->Attrs().Get<std::vector<int>>("sections"));
36+
const size_t outs_number = outs_names.size();
37+
std::vector<framework::DDim> outs_dims;
38+
outs_dims.reserve(outs_number);
39+
40+
if (num > 0) {
41+
int64_t in_axis_dim = in_dims[0];
42+
PADDLE_ENFORCE_EQ(in_axis_dim % num, 0,
43+
"tensor split does not result"
44+
" in an equal division");
45+
size_t out_axis_dim = in_axis_dim / num;
46+
for (size_t i = 0; i < outs_number; ++i) {
47+
auto dim = in_dims;
48+
dim[0] = out_axis_dim;
49+
outs_dims.push_back(dim);
50+
}
51+
} else if (sections.size() > 0) {
52+
PADDLE_ENFORCE_EQ(sections.size(), outs_number,
53+
"tensor split sections size"
54+
"should be equal to output size.");
55+
for (size_t i = 0; i < outs_number; ++i) {
56+
auto dim = in_dims;
57+
dim[0] = sections[i];
58+
outs_dims.push_back(dim);
59+
}
60+
}
61+
ctx->SetOutputsDim("Out", outs_dims);
62+
}
63+
};
64+
65+
class SplitByrefOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  // Declares the op's inputs, outputs, attributes and documentation.
  SplitByrefOpMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "(Tensor) Input tensor of the split operator.");
    AddOutput("Out", "(Tensor) Output tensors of the split operator.")
        .AsDuplicable();
    // FIX: typo "sevaral" and grammar cleaned up in the user-facing doc.
    AddComment(R"DOC(
SplitByref operator

Split the source tensor into several tensors along axis 0. No copy is
performed by this operator; the output tensors share the same blocks of
memory as the input.
)DOC");
    AddAttr<std::vector<int>>("sections",
                              "(vector<int>) "
                              "the length of each output along the "
                              "specified axis.")
        .SetDefault(std::vector<int>{});
    // FIX: adjacent literals previously rendered as "0)Number" — a
    // separating space was missing.
    AddAttr<int>("num",
                 "(int, default 0) "
                 "Number of sub-tensors. This must evenly divide "
                 "Input.dims()[axis]")
        .SetDefault(0);
  }
};
90+
91+
} // namespace operators
92+
} // namespace paddle
93+
94+
namespace ops = paddle::operators;
95+
// split_byref's gradient op is concat (emitted by ops::SplitGradMaker),
// so the concat CPU kernel must be linked into this binary.
// NOTE: concat op default axis must be 0!
96+
USE_CPU_ONLY_OP(concat);
97+
98+
// Register the forward op, its proto maker, and its grad-desc maker.
REGISTER_OPERATOR(split_byref, ops::SplitByrefOp, ops::SplitByrefOpMaker,
99+
ops::SplitGradMaker);
100+
// NOTE(review): registered only for float; also this uses CPUPlace as the
// kernel's device template argument while the CUDA side uses
// CUDADeviceContext — confirm CPUDeviceContext isn't the intended parameter.
REGISTER_OP_CPU_KERNEL(
101+
split_byref, ops::SplitByrefOpKernel<paddle::platform::CPUPlace, float>);
Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#include "paddle/fluid/operators/split_byref_op.h"
16+
namespace ops = paddle::operators;
17+
// Register the by-reference split kernel for CUDA; only a float
// instantiation is provided here.
REGISTER_OP_CUDA_KERNEL(
18+
split_byref,
19+
ops::SplitByrefOpKernel<paddle::platform::CUDADeviceContext, float>);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License. */
14+
15+
#pragma once
16+
17+
#include <vector>
18+
#include "paddle/fluid/framework/op_registry.h"
19+
20+
namespace paddle {
21+
namespace operators {
22+
23+
template <typename DeviceContext, typename T>
24+
class SplitByrefOpKernel : public framework::OpKernel<T> {
25+
public:
26+
void Compute(const framework::ExecutionContext& ctx) const override {
27+
auto* in = ctx.Input<framework::Tensor>("X");
28+
auto outs = ctx.MultiOutput<framework::Tensor>("Out");
29+
auto place = ctx.GetPlace();
30+
31+
size_t row_offset = 0;
32+
for (size_t i = 0; i < outs.size(); ++i) {
33+
// NOTE: no need to call mutable_data here to allocate memory.
34+
auto* out = outs[i];
35+
VLOG(3) << "spliting by ref: " << row_offset << " " << out->dims()[0];
36+
*out = std::move(in->Slice(row_offset, row_offset + out->dims()[0]));
37+
row_offset += out->dims()[0];
38+
}
39+
}
40+
};
41+
42+
} // namespace operators
43+
} // namespace paddle

paddle/fluid/operators/split_op.cc

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -108,21 +108,6 @@ This operator splits the input tensor into multiple sub-tensors.
108108
}
109109
};
110110

111-
class SplitGradMaker : public framework::SingleGradOpDescMaker {
112-
public:
113-
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
114-
115-
protected:
116-
std::unique_ptr<framework::OpDesc> Apply() const override {
117-
auto op = new framework::OpDesc();
118-
op->SetType("concat");
119-
op->SetInput("X", OutputGrad("Out"));
120-
op->SetOutput("Out", InputGrad("X"));
121-
op->SetAttrMap(Attrs());
122-
return std::unique_ptr<framework::OpDesc>(op);
123-
}
124-
};
125-
126111
} // namespace operators
127112
} // namespace paddle
128113

paddle/fluid/operators/split_op.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,5 +44,20 @@ class SplitOpKernel : public framework::OpKernel<T> {
4444
}
4545
};
4646

47+
class SplitGradMaker : public framework::SingleGradOpDescMaker {
48+
public:
49+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
50+
51+
protected:
52+
std::unique_ptr<framework::OpDesc> Apply() const override {
53+
auto op = new framework::OpDesc();
54+
op->SetType("concat");
55+
op->SetInput("X", OutputGrad("Out"));
56+
op->SetOutput("Out", InputGrad("X"));
57+
op->SetAttrMap(Attrs());
58+
return std::unique_ptr<framework::OpDesc>(op);
59+
}
60+
};
61+
4762
} // namespace operators
4863
} // namespace paddle

python/paddle/fluid/distribute_transpiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -825,7 +825,7 @@ def _append_split_op(self, program, gradblocks):
825825
for v in splited_vars:
826826
sections.append(v.shape[0])
827827
program.global_block().append_op(
828-
type="split",
828+
type="split_byref",
829829
inputs={"X": orig_var},
830830
outputs={"Out": splited_vars},
831831
attrs={"sections": sections} # assume split evenly

python/paddle/fluid/tests/unittests/test_split_op.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
class TestSplitOp(OpTest):
2121
def setUp(self):
22-
self.op_type = "split"
22+
self._set_op_type()
2323
axis = 1
2424
x = np.random.random((4, 5, 6)).astype('float32')
2525
out = np.split(x, [2, 3], axis)
@@ -28,12 +28,20 @@ def setUp(self):
2828
self.outputs = {'Out': [('out%d' % i, out[i]) \
2929
for i in xrange(len(out))]}
3030

31+
def _set_op_type(self):
32+
self.op_type = "split"
33+
3134
def test_check_output(self):
3235
self.check_output()
3336

3437
def test_check_grad(self):
3538
self.check_grad(['X'], ['out0', 'out1', 'out2'])
3639

3740

41+
class TestSplitByrefOp(TestSplitOp):
    # FIX: deriving from OpTest directly meant _set_op_type was never
    # invoked (no setUp defined it), so this case exercised nothing.
    # Inheriting TestSplitOp reuses its setUp and check methods while
    # targeting the by-reference op.
    def _set_op_type(self):
        self.op_type = "split_byref"
44+
45+
3846
if __name__ == '__main__':
3947
unittest.main()

0 commit comments

Comments
 (0)