
Commit 2a76b42

Merge pull request #5419 from reyoung/feature/shrink_memory_op
Feature/shrink memory op
2 parents d4d8f74 + 272a272 commit 2a76b42

5 files changed: 267 additions, 38 deletions


paddle/operators/array_operator.h

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

#pragma once
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {
class ArrayOp : public framework::OperatorBase {
 public:
  ArrayOp(const std::string &type, const framework::VariableNameMap &inputs,
          const framework::VariableNameMap &outputs,
          const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

 protected:
  size_t GetOffset(const framework::Scope &scope,
                   const platform::DeviceContext &dev_ctx) const {
    auto *i = scope.FindVar(Input("I"));
    PADDLE_ENFORCE(i != nullptr, "I must be set");
    auto &i_tensor = i->Get<framework::LoDTensor>();
    PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
    size_t offset;
    if (platform::is_gpu_place(i_tensor.place())) {
      // FIXME: Avoid copy from GPU to CPU
      framework::Tensor t;
      t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
      dev_ctx.Wait();
      offset = static_cast<size_t>(*t.data<int64_t>());
    } else {
      offset = static_cast<size_t>(*i_tensor.data<int64_t>());
    }
    return offset;
  }
};

}  // namespace operators
}  // namespace paddle
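
The GetOffset helper above is what the array-indexed operators in this commit use to turn their "I" input into a plain offset: "I" must be a one-element int64 LoDTensor, and if it lives on the GPU it is first copied to the CPU (see the FIXME). A self-contained plain-Python sketch of that contract, with a hypothetical helper name, purely for illustration:

import numpy

def get_offset(i_tensor):
    # Mirrors PADDLE_ENFORCE_EQ(i_tensor.numel(), 1): "I" must hold exactly
    # one int64 value, which becomes the array index / time-step offset.
    assert i_tensor.size == 1 and i_tensor.dtype == numpy.int64
    return int(i_tensor.reshape(-1)[0])

print(get_offset(numpy.zeros([1], dtype='int64')))  # 0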
Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

 Licensed under the Apache License, Version 2.0 (the "License");
 you may not use this file except in compliance with the License.
 You may obtain a copy of the License at

 http://www.apache.org/licenses/LICENSE-2.0

 Unless required by applicable law or agreed to in writing, software
 distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/operators/array_operator.h"
#include "paddle/operators/math/math_function.h"

namespace paddle {
namespace operators {

class ShrinkRNNMemoryOp : public ArrayOp {
 public:
  ShrinkRNNMemoryOp(const std::string &type,
                    const framework::VariableNameMap &inputs,
                    const framework::VariableNameMap &outputs,
                    const framework::AttributeMap &attrs)
      : ArrayOp(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto *x_var = scope.FindVar(Input("X"));
    PADDLE_ENFORCE(x_var != nullptr, "Input X must be set");
    auto &x_tensor = x_var->Get<framework::LoDTensor>();
    size_t offset = this->GetOffset(scope, dev_ctx);
    auto *rank_table_var = scope.FindVar(Input("RankTable"));
    PADDLE_ENFORCE(rank_table_var != nullptr, "RankTable must be set");
    auto &rank_table = rank_table_var->Get<framework::LoDRankTable>();

    auto &rank_items = rank_table.items();
    int dst_num_rows =
        std::lower_bound(rank_items.begin(), rank_items.end(), offset,
                         [](const framework::LoDRankTable::TableItem &a,
                            size_t b) { return a.length > b; }) -
        rank_items.begin();

    auto *out_var = scope.FindVar(Output("Out"));
    PADDLE_ENFORCE(out_var != nullptr, "Output Out must be set");
    auto &out_tensor = *out_var->GetMutable<framework::LoDTensor>();
    if (dst_num_rows != 0) {
      out_tensor.ShareDataWith(x_tensor.Slice(0, dst_num_rows));
    }
  }
};

class ShrinkRNNMemoryOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  ShrinkRNNMemoryOpProtoMaker(framework::OpProto *proto,
                              framework::OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddInput("X", "");
    AddInput("RankTable", "");
    AddInput("I", "");
    AddOutput("Out", "");
    AddComment("");
  }
};

class ShrinkRNNMemoryInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasInput("I"));
    PADDLE_ENFORCE(context->HasInput("RankTable"));
    context->SetOutputDim("Out", context->GetInputDim("X"));
  }
};

class ShrinkRNNMemoryGradOp : public ArrayOp {
 public:
  ShrinkRNNMemoryGradOp(const std::string &type,
                        const framework::VariableNameMap &inputs,
                        const framework::VariableNameMap &outputs,
                        const framework::AttributeMap &attrs)
      : ArrayOp(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
           const platform::DeviceContext &dev_ctx) const override {
    auto *dout_var = scope.FindVar(Input(framework::GradVarName("Out")));
    auto *dx_var = scope.FindVar(Output(framework::GradVarName("X")));
    PADDLE_ENFORCE(dx_var != nullptr, "Input Gradient should not be nullptr");
    auto *x_var = scope.FindVar(Input("X"));
    PADDLE_ENFORCE(x_var != nullptr);

    auto &x_tensor = x_var->Get<framework::LoDTensor>();
    auto &dx_tensor = *dx_var->GetMutable<framework::LoDTensor>();
    dx_tensor.Resize(x_tensor.dims());
    dx_tensor.mutable_data(x_tensor.place(), x_tensor.type());

    if (dout_var == nullptr) {  // dx_tensor fill zero
      math::set_constant(dev_ctx, &dx_tensor, 0.0f);
    } else {
      auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
      auto height = dout_tensor.dims()[0];
      dx_tensor.Slice(0, static_cast<int>(height))
          .CopyFrom(dout_tensor, dout_tensor.place(), dev_ctx);
      if (dx_tensor.dims()[0] < height) {
        auto rest_tensor = dx_tensor.Slice(
            static_cast<int>(height), static_cast<int>(dout_tensor.dims()[0]));
        math::set_constant(dev_ctx, &rest_tensor, 0.0f);
      }
    }
  }
};

class ShrinkRNNMemoryGradInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    PADDLE_ENFORCE(context->HasInput("X"));
    PADDLE_ENFORCE(context->HasOutput(framework::GradVarName("X")));
    context->SetOutputDim(framework::GradVarName("X"),
                          context->GetInputDim("X"));
  }
};

class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<framework::OpDescBind> Apply() const override {
    auto *op = new framework::OpDescBind();
    op->SetType("shrink_rnn_memory_grad");
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
    return std::unique_ptr<framework::OpDescBind>(op);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(shrink_rnn_memory, ops::ShrinkRNNMemoryOp,
                  ops::ShrinkRNNMemoryInferShape,
                  ops::ShrinkRNNMemoryOpProtoMaker, ops::ShrinkRNNGradOpMaker);
REGISTER_OPERATOR(shrink_rnn_memory_grad, ops::ShrinkRNNMemoryGradOp,
                  ops::ShrinkRNNMemoryGradInferShape);
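
The heart of the forward pass is the std::lower_bound call: with the rank-table items sorted by decreasing sequence length, the comparator `a.length > b` makes lower_bound return the first item whose length is not greater than `offset`, so the distance from the beginning is the number of sequences still active at that time step, and only that many leading rows of X are kept. A minimal plain-Python sketch of the same rule (hypothetical names, not the committed code), using the lengths implied by the LoD [[0, 2, 5, 6]] from the test further down:

def shrink_rows(x_rows, sorted_lengths, offset):
    # Count sequences whose length is strictly greater than the current step;
    # this matches lower_bound with the comparator `a.length > b` when
    # sorted_lengths is in descending order.
    dst_num_rows = sum(1 for length in sorted_lengths if length > offset)
    return x_rows[:dst_num_rows]

sorted_lengths = [3, 2, 1]             # sequence lengths, longest first
rows = ['row0', 'row1', 'row2']        # stand-ins for the rows of X
print(shrink_rows(rows, sorted_lengths, 0))  # all 3 rows kept
print(shrink_rows(rows, sorted_lengths, 1))  # 2 rows kept
print(shrink_rows(rows, sorted_lengths, 2))  # 1 row kept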

paddle/operators/tensor_array_read_write_op.cc

Lines changed: 5 additions & 35 deletions
@@ -11,48 +11,18 @@
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
-#include "paddle/framework/lod_tensor_array.h"
-#include "paddle/framework/op_registry.h"
+#include "paddle/operators/array_operator.h"

 namespace paddle {
 namespace operators {
-class ArrayOpBase : public framework::OperatorBase {
- public:
-  ArrayOpBase(const std::string &type, const framework::VariableNameMap &inputs,
-              const framework::VariableNameMap &outputs,
-              const framework::AttributeMap &attrs)
-      : OperatorBase(type, inputs, outputs, attrs) {}
-  void Run(const framework::Scope &scope,
-           const platform::DeviceContext &dev_ctx) const override {}
-
- protected:
-  size_t GetOffset(const framework::Scope &scope,
-                   const platform::DeviceContext &dev_ctx) const {
-    auto *i = scope.FindVar(Input("I"));
-    PADDLE_ENFORCE(i != nullptr, "I must be set");
-    auto &i_tensor = i->Get<framework::LoDTensor>();
-    PADDLE_ENFORCE_EQ(i_tensor.numel(), 1);
-    size_t offset;
-    if (platform::is_gpu_place(i_tensor.place())) {
-      // FIXME: Avoid copy from GPU to CPU
-      framework::Tensor t;
-      t.CopyFrom(i_tensor, platform::CPUPlace(), dev_ctx);
-      dev_ctx.Wait();
-      offset = static_cast<size_t>(*t.data<int64_t>());
-    } else {
-      offset = static_cast<size_t>(*i_tensor.data<int64_t>());
-    }
-    return offset;
-  }
-};

-class WriteToArrayOp : public ArrayOpBase {
+class WriteToArrayOp : public ArrayOp {
  public:
   WriteToArrayOp(const std::string &type,
                  const framework::VariableNameMap &inputs,
                  const framework::VariableNameMap &outputs,
                  const framework::AttributeMap &attrs)
-      : ArrayOpBase(type, inputs, outputs, attrs) {}
+      : ArrayOp(type, inputs, outputs, attrs) {}

   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
@@ -122,13 +92,13 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
   }
 };

-class ReadFromArrayOp : public ArrayOpBase {
+class ReadFromArrayOp : public ArrayOp {
  public:
   ReadFromArrayOp(const std::string &type,
                   const framework::VariableNameMap &inputs,
                   const framework::VariableNameMap &outputs,
                   const framework::AttributeMap &attrs)
-      : ArrayOpBase(type, inputs, outputs, attrs) {}
+      : ArrayOp(type, inputs, outputs, attrs) {}
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
     auto *x = scope.FindVar(Input("X"));
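
The rest of this file is untouched; the change only rebases WriteToArrayOp and ReadFromArrayOp onto the shared ArrayOp from array_operator.h. For orientation, a rough plain-Python sketch of the semantics these two ops keep (hypothetical names, assuming write-at-index grows the array as needed; not the committed code):

def write_to_array(array, i, x):
    # Store x at position i, growing the list if i is past the end.
    if i >= len(array):
        array.extend([None] * (i + 1 - len(array)))
    array[i] = x

def read_from_array(array, i):
    return array[i]

arr = []
write_to_array(arr, 0, 'step-0 state')
write_to_array(arr, 1, 'step-1 state')
print(read_from_array(arr, 1))  # step-1 state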

python/paddle/v2/framework/layers.py

Lines changed: 16 additions & 3 deletions
@@ -891,13 +891,13 @@ def zeros(shape, dtype, main_program=None):
 
 def increment(x, value=1.0, main_program=None):
     helper = LayerHelper("increment", **locals())
-    tmp = helper.create_tmp_variable(dtype=x.data_type)
+    out = helper.create_tmp_variable(dtype=x.data_type)
     helper.append_op(
         type='increment',
         inputs={'X': [x]},
-        outputs={'Out': [tmp]},
+        outputs={'Out': [out]},
         attrs={'step': value})
-    return tmp
+    return out
 
 
 def array_write(x, i, array=None, main_program=None):
@@ -928,3 +928,16 @@ def array_read(array, i, main_program=None):
                 'I': [i]},
         outputs={'Out': [out]})
     return out
+
+
+def shrink_memory(x, i, table, main_program=None):
+    helper = LayerHelper('shrink_memory', **locals())
+    out = helper.create_tmp_variable(dtype=x.data_type)
+    helper.append_op(
+        type='shrink_rnn_memory',
+        inputs={'X': [x],
+                'I': [i],
+                'RankTable': [table]},
+        outputs={'Out': [out]},
+        attrs={})
+    return out
Lines changed: 47 additions & 0 deletions
@@ -0,0 +1,47 @@
import unittest
import paddle.v2.framework.core as core
from paddle.v2.framework.executor import Executor
import paddle.v2.framework.layers as layers
from paddle.v2.framework.backward import append_backward_ops
from paddle.v2.framework.framework import g_main_program
import numpy


class TestShrinkRNNMemory(unittest.TestCase):
    def test_shrink_rnn_memory(self):
        x = layers.data('x', shape=[100], data_type='float32')
        x.stop_gradient = False
        table = layers.lod_rank_table(x=x)
        i = layers.zeros(dtype='int64', shape=[1])
        mem1 = layers.shrink_memory(x=x, i=i, table=table)
        i = layers.increment(x=i)
        i.stop_gradient = True
        mem2 = layers.shrink_memory(x=mem1, i=i, table=table)
        i = layers.increment(x=i)
        i.stop_gradient = True
        mem3 = layers.shrink_memory(x=mem2, i=i, table=table)

        cpu = core.CPUPlace()
        tensor = core.LoDTensor()
        tensor.set_lod([[0, 2, 5, 6]])
        tensor_np = numpy.random.random(size=(3, 100)).astype('float32')
        tensor.set(tensor_np, cpu)
        exe = Executor(cpu)
        outs = map(numpy.array,
                   exe.run(feed={'x': tensor}, fetch_list=[mem1, mem2, mem3]))
        self.assertTrue(numpy.allclose(tensor_np[0:3], outs[0]))
        self.assertTrue(numpy.allclose(tensor_np[0:2], outs[1]))
        self.assertTrue(numpy.allclose(tensor_np[0:1], outs[2]))

        mem3_mean = layers.mean(x=mem3)
        append_backward_ops(loss=mem3_mean)
        x_grad = map(numpy.array,
                     exe.run(feed={'x': tensor},
                             fetch_list=[
                                 g_main_program.global_block().var('x@GRAD')
                             ]))[0]
        self.assertAlmostEqual(1.0, x_grad.sum(), delta=0.1)


if __name__ == '__main__':
    unittest.main()
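
A quick back-of-the-envelope check of the last assertion (a note, not part of the commit): mem3 shares its data with the first row of x and holds 1 x 100 elements, so mean(mem3) assigns each of those elements a gradient of 1/100, and the test expects x@GRAD to sum to roughly 1.0 (hence delta=0.1).

# Expected x_grad.sum() for the test above, assuming the kept row is the only
# significant contribution to x@GRAD:
rows, cols = 1, 100
expected = rows * cols * (1.0 / (rows * cols))
print(expected)  # 1.0 -> assertAlmostEqual(1.0, x_grad.sum(), delta=0.1)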
