
Commit f7e9fe5

[Memory] More memory optimization policy (#8690)
* add memopt level
* add opt level for image classification demo
* clean code
* add delete op
* clean code
* test machine translation demo
* clean code
* clean code
* skip fill constant with force cpu
* clean code
* clean code
* refine code
* clean code
* fix bug
1 parent 607eec3 commit f7e9fe5
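
For orientation, a minimal usage sketch of the Python-side API this commit exposes. The network-building code is elided, and the comments paraphrase the behaviour implemented in the diffs below:

    import paddle.fluid as fluid

    # ... build the network, loss and optimizer on the default main program ...

    # Reuse the memory of variables once they are dead. level=0 (default) only
    # reuses a cached variable whose shape matches exactly; level=1 also lets a
    # variable fall into a larger cache slot when the element counts allow it.
    fluid.memory_optimize(fluid.default_main_program(), print_log=True, level=0)

    # Alternatively, insert delete_var ops so variables are erased from the
    # scope as soon as they leave the live set, instead of being reused.
    # fluid.release_memory(fluid.default_main_program())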

13 files changed: +173 -33 lines changed

paddle/fluid/framework/block_desc.cc

Lines changed: 8 additions & 0 deletions
@@ -135,6 +135,14 @@ OpDesc *BlockDesc::PrependOp() {
   return ops_.front().get();
 }
 
+OpDesc *BlockDesc::InsertOp(size_t index) {
+  need_update_ = true;
+  auto it = ops_.begin() + index;
+  std::unique_ptr<OpDesc> new_op(new OpDesc(this));
+  it = ops_.insert(it, std::move(new_op));
+  return (*it).get();
+}
+
 void BlockDesc::RemoveOp(size_t s, size_t e) {
   if (ops_.begin() + s == ops_.end() || ops_.begin() + e == ops_.end()) {
     return;

paddle/fluid/framework/block_desc.h

Lines changed: 2 additions & 0 deletions
@@ -87,6 +87,8 @@ class BlockDesc {
 
   OpDesc *PrependOp();
 
+  OpDesc *InsertOp(size_t index);
+
   void RemoveOp(size_t s, size_t e);
 
   std::vector<OpDesc *> AllOps() const;

paddle/fluid/framework/scope.cc

Lines changed: 13 additions & 0 deletions
@@ -16,6 +16,7 @@ limitations under the License. */
 
 #include <memory>  // for unique_ptr
 #include <mutex>   // for call_once
+#include <set>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/threadpool.h"
 #include "paddle/fluid/string/printf.h"
@@ -102,6 +103,18 @@ void Scope::DeleteScope(Scope* scope) {
   }
 }
 
+void Scope::EraseVars(std::vector<std::string>& var_names) {
+  std::set<std::string> var_set(var_names.begin(), var_names.end());
+  for (auto it = vars_.begin(); it != vars_.end();) {
+    if (var_set.find(it->first) != var_set.end()) {
+      delete it->second;
+      it = vars_.erase(it);
+    } else {
+      ++it;
+    }
+  }
+}
+
 void Scope::Rename(const std::string& origin_name,
                    const std::string& new_name) const {
   auto origin_it = vars_.find(origin_name);

paddle/fluid/framework/scope.h

Lines changed: 2 additions & 0 deletions
@@ -51,6 +51,8 @@ class Scope {
   /// Create a variable with a scope-unique name.
   Variable* Var(std::string* name = nullptr);
 
+  void EraseVars(std::vector<std::string>& var_names);
+
   /// Find a variable in the scope or any of its ancestors. Returns
   /// nullptr if cannot find.
   Variable* FindVar(const std::string& name) const;
Lines changed: 52 additions & 0 deletions (new file: the delete_var operator)
@@ -0,0 +1,52 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+
+namespace paddle {
+namespace operators {
+class DeleteVarOp : public framework::OperatorBase {
+ public:
+  DeleteVarOp(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+  void RunImpl(const framework::Scope &scope,
+               const platform::Place &place) const override {
+    // get device context from pool
+    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
+    auto &dev_ctx = *pool.Get(place);
+    dev_ctx.Wait();
+
+    auto delete_var_names = Inputs("X");
+    const_cast<framework::Scope &>(scope).EraseVars(delete_var_names);
+  }
+};
+
+class DeleteVarOpInfoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  DeleteVarOpInfoMaker(OpProto *proto, OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of delete op").AsDuplicable();
+    AddComment(R"DOC(
+Delete Operator.
+It should not be configured by users directly.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+REGISTER_OPERATOR(delete_var, paddle::operators::DeleteVarOp,
+                  paddle::framework::EmptyGradOpMaker,
+                  paddle::operators::DeleteVarOpInfoMaker);

paddle/fluid/pybind/protobuf.cc

Lines changed: 2 additions & 0 deletions
@@ -161,6 +161,8 @@ void BindBlockDesc(py::module &m) {
            py::return_value_policy::reference)
       .def("prepend_op", &BlockDesc::PrependOp,
            py::return_value_policy::reference)
+      .def("insert_op", &BlockDesc::InsertOp,
+           py::return_value_policy::reference)
       .def("remove_op", &BlockDesc::RemoveOp)
       .def("var",
            [](BlockDesc &self, py::bytes byte_name) {

python/paddle/fluid/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,7 @@
 from concurrency import (Go, make_channel, channel_send, channel_recv,
                          channel_close)
 import clip
-from memory_optimization_transpiler import memory_optimize
+from memory_optimization_transpiler import memory_optimize, release_memory
 import profiler
 import unique_name
 
@@ -63,6 +63,7 @@
     'SimpleDistributeTranspiler',
     'DistributeTranspiler',
     'memory_optimize',
+    'release_memory',
     'profiler',
     'unique_name',
 ]

python/paddle/fluid/backward.py

Lines changed: 2 additions & 1 deletion
@@ -457,7 +457,8 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
                 "Out": [_append_grad_suffix_(loss.name)]
             }, {"shape": [1],
                 "value": 1.0,
-                "dtype": loss.dtype})
+                "dtype": loss.dtype,
+                "force_cpu": False})
     root_block.desc.append_op().copy_from(op_desc)
 
     block_no_grad_set = set(map(_strip_grad_suffix_, no_grad_dict[0]))

python/paddle/fluid/memory_optimization_transpiler.py

Lines changed: 80 additions & 24 deletions
@@ -29,7 +29,10 @@
     core.VarDesc.VarType.BOOL: 1
 }
 
-sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
+sub_block_ops = [
+    "while", "while_grad", "parallel_do", "parallel_do_grad",
+    "conditional_block", "conditional_block_grad"
+]
 
 PRINT_LOG = False
 
@@ -122,36 +125,82 @@ def _find_var(self, block_desc, var_name, is_forward):
         else:
             return block_desc.find_var_recursive(str(var_name))
 
-    def memory_optimize(self):
-        def check_var_validity(block_desc, x, is_forward):
-            if str(x) == "@EMPTY@":
-                return False
-            if not self._has_var(block_desc, x, is_forward):
-                return False
-            if self._find_var(block_desc, x, is_forward).persistable():
-                return False
-            if self._find_var(
-                    block_desc, x,
-                    is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
-                return False
-            if x in self._skip_opt:
-                return False
-            if not self._find_var(block_desc, x, is_forward).shape():
-                return False
-            return True
+    def _check_var_validity(self, block_desc, x, is_forward):
+        if str(x) == "@EMPTY@":
+            return False
+        if not self._has_var(block_desc, x, is_forward):
+            return False
+        if self._find_var(block_desc, x, is_forward).persistable():
+            return False
+        if self._find_var(block_desc, x,
+                          is_forward).type() != core.VarDesc.VarType.LOD_TENSOR:
+            return False
+        if x in self._skip_opt:
+            return False
+        if not self._find_var(block_desc, x, is_forward).shape():
+            return False
+        return True
 
+    def _update_skip_opt_set(self):
+        for i in range(self.op_size):
+            op = self._ops[i]
+            if op.type() == "fill_constant" and op.attr("force_cpu") == True:
+                self._skip_opt.update(op.output_arg_names())
+
+    def release_memory(self):
         self._build_graph()
         self._dataflow_analyze()
+        self._update_skip_opt_set()
+        fwd_id = 0
+        bwd_id = 0
+        for i in range(self.op_size):
+            op = self._ops[i]
+            if op.type() in sub_block_ops:
+                continue
+            block_desc = op.block()
+            is_forward = i < self._forward_num
+            in_diff, out_diff = self._get_diff(self._live_in[i],
+                                               self._live_out[i])
+            can_optimize = filter(
+                lambda x: self._check_var_validity(block_desc, x, is_forward),
+                in_diff)
+            if can_optimize:
+                index = i + fwd_id + 1 if is_forward else i - self._forward_num + bwd_id + 1
+                delete_op = block_desc.insert_op(index)
+                delete_op.set_type("delete_var")
+                delete_op.set_input("X", can_optimize)
+                if is_forward:
+                    fwd_id += 1
+                else:
+                    bwd_id += 1
+
+    def memory_optimize(self, level=0):
+        def compare_shape(x_shape, cache_shape, opt_level):
+            if opt_level == 0:
+                return x_shape == cache_shape
+            if opt_level == 1:
+                if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
+                    return False
+                x_size = abs(reduce(lambda x, y: x * y, x_shape))
+                cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
+                if x_size <= cache_size:
+                    return True
+            return False
+
+        self._build_graph()
+        self._dataflow_analyze()
+        self._update_skip_opt_set()
         self.pool = []
         for i in range(self.op_size):
             op = self._ops[i]
             if op.type() in sub_block_ops:
                 continue
             block_desc = op.block()
+            self.current_block_desc = block_desc
             is_forward = i < self._forward_num
             if self.pool:
                 defs_can_optimize = filter(
-                    lambda x: check_var_validity(block_desc, x, is_forward),
+                    lambda x: self._check_var_validity(block_desc, x, is_forward),
                     self._defs[i])
                 out_pair = [
                     (x, self._find_var(block_desc, x, is_forward).shape())
@@ -164,7 +213,7 @@ def check_var_validity(block_desc, x, is_forward):
                     for index, cache_pair in enumerate(self.pool):
                         cache_var = cache_pair[0]
                         cache_shape = cache_pair[1]
-                        if x_shape == cache_shape:
+                        if compare_shape(x_shape, cache_shape, level):
                             if self._has_var(block_desc, cache_var, is_forward):
                                 x_dtype = self._find_var(block_desc, x,
                                                          is_forward).dtype()
@@ -196,7 +245,7 @@ def check_var_validity(block_desc, x, is_forward):
             in_diff, out_diff = self._get_diff(self._live_in[i],
                                                self._live_out[i])
             can_optimize = filter(
-                lambda x: check_var_validity(block_desc, x, is_forward),
+                lambda x: self._check_var_validity(block_desc, x, is_forward),
                 in_diff)
             if can_optimize:
                 for var_name in can_optimize:
@@ -270,7 +319,8 @@ def _get_cfgs(input_program):
         ([block_desc.op(i) for i in range(op_size)], op_size, set()))
 
     sub_block_pair = [("while", "while_grad"), ("parallel_do",
-                                                "parallel_do_grad")]
+                                                "parallel_do_grad"),
+                      ("conditional_block", "conditional_block_grad")]
 
     ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
 
@@ -281,9 +331,15 @@ def _get_cfgs(input_program):
     return cfgs
 
 
-def memory_optimize(input_program, print_log=False):
+def memory_optimize(input_program, print_log=False, level=0):
     global PRINT_LOG
     PRINT_LOG = print_log
    cfgs = _get_cfgs(input_program)
     for cfg in cfgs:
-        cfg.memory_optimize()
+        cfg.memory_optimize(level)
+
+
+def release_memory(input_program):
+    cfgs = _get_cfgs(input_program)
+    for cfg in cfgs:
+        cfg.release_memory()
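
For reference, a small self-contained sketch of the level-1 matching rule implemented by compare_shape above. The only departure from the diff is the functools import so the snippet also runs on Python 3 (the original relies on Python 2's builtin reduce); the example shapes are illustrative, not taken from the commit:

    from functools import reduce

    def compare_shape(x_shape, cache_shape, opt_level):
        # level 0: only an exact shape match may reuse the cached variable
        if opt_level == 0:
            return x_shape == cache_shape
        # level 1: a batch-size (-1) dim may only pair with another batch-size dim,
        # and the variable may reuse a slot with an equal or larger element count
        if opt_level == 1:
            if (x_shape[0] == -1) ^ (cache_shape[0] == -1):
                return False
            x_size = abs(reduce(lambda x, y: x * y, x_shape))
            cache_size = abs(reduce(lambda x, y: x * y, cache_shape))
            if x_size <= cache_size:
                return True
        return False

    print(compare_shape([-1, 128], [-1, 128], 0))  # True: identical shapes
    print(compare_shape([-1, 128], [-1, 256], 0))  # False: level 0 needs an exact match
    print(compare_shape([-1, 128], [-1, 256], 1))  # True: 128 <= 256, both batch-sized
    print(compare_shape([-1, 128], [32, 256], 1))  # False: -1 dim vs. fixed dim never pair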

python/paddle/fluid/tests/book_memory_optimization/test_memopt_fit_a_line.py

Lines changed: 1 addition & 2 deletions
@@ -50,6 +50,7 @@
 sgd_optimizer.minimize(avg_cost)
 
 fluid.memory_optimize(fluid.default_main_program(), print_log=True)
+# fluid.release_memory(fluid.default_main_program())
 
 BATCH_SIZE = 200
 
@@ -69,8 +70,6 @@
 
 PASS_NUM = 100
 for pass_id in range(PASS_NUM):
-    fluid.io.save_persistables(exe, "./fit_a_line.model/")
-    fluid.io.load_persistables(exe, "./fit_a_line.model/")
     for data in train_reader():
         avg_loss_value, = exe.run(fluid.default_main_program(),
                                   feed=feeder.feed(data),
