Commit 0a63234

follow comments. test=develop

1 parent 9e87fbe commit 0a63234

12 files changed: +70 additions, -55 deletions

paddle/fluid/framework/details/build_strategy.cc

Lines changed: 6 additions & 0 deletions
@@ -53,6 +53,12 @@ class ParallelExecutorPassBuilder : public ir::PassBuilder {
       AppendPass("fuse_relu_depthwise_conv_pass");
     }
 
+    // NOTE(dzhwinter): A note for automatical inplace.
+    // 1. modify program desc passes should put
+    // before inplace pass.
+    // 2. manually configured inplace should put
+    // before inplace_pass
+
     // Add automatically inplace.
     if (strategy_.enable_inplace_) {
       AppendPass("inplace_pass");

paddle/fluid/framework/details/build_strategy.h

Lines changed: 3 additions & 0 deletions
@@ -80,6 +80,9 @@ struct BuildStrategy {
 
   bool memory_early_delete_{false};
 
+  // TODO(dzhwinter):
+  // make enable_inplace, memory_optimize_
+  // memory_early_delete_ true by default
   bool enable_inplace_{false};
 
   bool enable_sequential_execution_{false};

paddle/fluid/framework/details/graph_print_pass.h

Lines changed: 6 additions & 1 deletion
@@ -26,6 +26,11 @@ namespace details {
 constexpr char kGraphvizPath[] = "debug_graphviz_path";
 constexpr char kGraphviz[] = "graphviz";
 
+// NOTE(dzhwinter): If the graph contains circles.
+// the graph can not be topology sort.
+// This printer will print the whole graph
+// and highlight the circles. It's quite useful
+// for debug the deadlock and circles.
 class GraphvizNode {
  public:
  GraphvizNode(ir::Node* n, const int& i) : node_(n), id_(i) {}
@@ -37,7 +42,7 @@ class GraphvizNode {
   ir::Node* node_;
   int id_;
 };
-class GraphvizNode;
+
 typedef std::unordered_set<std::unique_ptr<GraphvizNode>> GraphvizNodes;
 
 class SSAGraphPrinter {
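The note added above explains the printer's purpose: a graph that contains circles (cycles) cannot be topologically sorted, so the pass dumps the whole graph and highlights the cycle for debugging. A minimal, self-contained illustration of that failure mode follows; it is a toy cycle detector based on Kahn's algorithm, not Paddle's graphviz printer, and the graphs in main() are made up for the example.

#include <iostream>
#include <queue>
#include <vector>

// Kahn's algorithm: repeatedly remove nodes with no incoming edges.
// If a cycle exists, some nodes never reach in-degree zero and the
// sort cannot cover the whole graph.
bool TopologySortPossible(const std::vector<std::vector<int>>& adj) {
  std::vector<int> in_degree(adj.size(), 0);
  for (const auto& outs : adj)
    for (int v : outs) ++in_degree[v];

  std::queue<int> ready;
  for (size_t i = 0; i < adj.size(); ++i)
    if (in_degree[i] == 0) ready.push(static_cast<int>(i));

  size_t visited = 0;
  while (!ready.empty()) {
    int u = ready.front();
    ready.pop();
    ++visited;
    for (int v : adj[u])
      if (--in_degree[v] == 0) ready.push(v);
  }
  // Nodes left with incoming edges mean the graph has a cycle.
  return visited == adj.size();
}

int main() {
  std::vector<std::vector<int>> dag = {{1}, {2}, {}};      // 0 -> 1 -> 2
  std::vector<std::vector<int>> cyclic = {{1}, {2}, {0}};  // 0 -> 1 -> 2 -> 0
  std::cout << std::boolalpha << TopologySortPossible(dag) << "\n";     // true
  std::cout << std::boolalpha << TopologySortPossible(cyclic) << "\n";  // false
}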

paddle/fluid/framework/details/memory_optimize_helper.cc

Lines changed: 28 additions & 19 deletions
@@ -13,23 +13,27 @@
 // limitations under the License.
 
 #include "paddle/fluid/framework/details/memory_optimize_helper.h"
+#include <functional>
 #include <iostream>
+#include <numeric>
 #include <sstream>
 #include <string>
 
 namespace paddle {
 namespace framework {
 namespace details {
 
+size_t NodeSizeInBytes(const VarDesc& node) {
+  auto shape = node.GetShape();
+  int size =
+      std::accumulate(shape.begin(), shape.end(), 1, std::multiplies<int>());
+  size_t type_size = SizeOfType(node.GetDataType());
+  return type_size * std::abs(size);
+}
+
 size_t NodeSizeInBytes(ir::Node* n) {
   auto* desc = FindVarDescInBlock(n);
-  auto shape = desc->GetShape();
-  size_t type_size = SizeOfType(desc->GetDataType());
-  int size = 1;
-  for (auto& s : shape) {
-    size *= s;
-  }
-  return type_size * std::abs(size);
+  return NodeSizeInBytes(*desc);
 }
 
 std::string DebugStringImpl(VarDesc* var) {
@@ -154,23 +158,28 @@ std::string OrderedNodeList::ToString() const {
 
 bool NodeCanReused(ir::Node* node) {
   if (node == nullptr || !node->IsVar() || node->IsCtrlVar()) return false;
-  auto* desc = node->Var();
-  auto type = desc->GetType();
-  if (desc->Persistable() || type != proto::VarType::LOD_TENSOR ||
-      desc->GetShape().empty()) {
-    return false;
-  }
-  // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
-  std::string name = node->Name();
-  if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
-    return false;
+  // auto* desc = node->Var();
+  bool flag = NodeCanReused(*node->Var());
   for (auto* op : node->inputs) {
     if (op->Op()->HasAttr("force_cpu")) {
       // op output force generated in cpu, can not be reused.
-      return framework::AttrReader(op->Op()->GetAttrMap())
-                 .Get<bool>("force_cpu") == 0;
+      flag &= framework::AttrReader(op->Op()->GetAttrMap())
+                  .Get<bool>("force_cpu") == 0;
     }
   }
+  return flag;
+}
+
+bool NodeCanReused(const VarDesc& node) {
+  auto type = node.GetType();
+  if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
+      node.GetShape().empty()) {
+    return false;
+  }
+  // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
+  std::string name = node.Name();
+  if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
+    return false;
   return true;
 }
 
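For reference, the new VarDesc overload of NodeSizeInBytes computes a tensor's footprint as the product of its shape dims times the element size, with std::abs guarding against the negative product that appears when a dimension is unknown (presumably the -1 used for a dynamic batch dimension). Below is a standalone sketch of that arithmetic only; SizeInBytes and the shape in main() are illustrative stand-ins, not Paddle's SizeOfType/VarDesc API.

#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

size_t SizeInBytes(const std::vector<int64_t>& shape, size_t element_size) {
  // Element count is the product of the dims; an unknown dim encoded as -1
  // can make the product negative, so std::abs keeps the result non-negative,
  // mirroring the helper in the diff above.
  int64_t numel = std::accumulate(shape.begin(), shape.end(),
                                  static_cast<int64_t>(1),
                                  std::multiplies<int64_t>());
  return element_size * static_cast<size_t>(std::abs(numel));
}

int main() {
  // A float32 tensor with a dynamic batch dimension, e.g. shape [-1, 3, 224, 224].
  std::cout << SizeInBytes({-1, 3, 224, 224}, sizeof(float)) << " bytes\n";
}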

paddle/fluid/framework/details/memory_optimize_helper.h

Lines changed: 6 additions & 0 deletions
@@ -86,12 +86,18 @@ class OrderedNodeList {
 // valid a tensor can be reuse or not
 bool NodeCanReused(ir::Node* node);
 
+// valid a tensor can be reuse or not.
+bool NodeCanReused(const VarDesc& node);
+
 // check op has subblock or not
 bool OpHasSubBlock(OpDesc* desc);
 
 // node memory size in bytes
 size_t NodeSizeInBytes(ir::Node* n);
 
+// node memory size in bytes
+size_t NodeSizeInBytes(const VarDesc&);
+
 std::string DebugString(ir::Node* var);
 
 VarDesc* FindVarDescInBlock(ir::Node* n);

paddle/fluid/framework/inplace_op_inference.h

Lines changed: 4 additions & 24 deletions
@@ -19,6 +19,7 @@
 #include <unordered_map>
 #include "glog/logging.h"
 #include "paddle/fluid/framework/block_desc.h"
+#include "paddle/fluid/framework/details/memory_optimize_helper.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/type_defs.h"
 
@@ -66,30 +67,9 @@ class InplaceInToOut : public InplaceOpInference {
                              const OpDesc& op_desc, BlockDesc* block) const = 0;
 
   bool TryInplaceInputOutput(const VarDesc& in, const VarDesc& out) const {
-    auto var_can_reused = [&](const VarDesc& node) -> bool {
-      auto type = node.GetType();
-      if (node.Persistable() || type != proto::VarType::LOD_TENSOR ||
-          node.GetShape().empty()) {
-        return false;
-      }
-      // vars can be @EMPTY@, @LR_DECAY_REUSE_ID@. For example, while_grad
-      std::string name = node.Name();
-      if (!name.empty() && name[0] == '@' && name[name.size() - 1] == '@')
-        return false;
-      return true;
-    };
-
-    auto var_size_in_bytes = [&](const VarDesc& node) -> size_t {
-      auto shape = node.GetShape();
-      int size = std::accumulate(shape.begin(), shape.end(), 1,
-                                 std::multiplies<int>());
-      size_t type_size = SizeOfType(node.GetDataType());
-      return type_size * std::abs(size);
-    };
-
-    return in.Name() != out.Name() && var_can_reused(in) &&
-           var_can_reused(out) &&
-           var_size_in_bytes(out) <= var_size_in_bytes(in);
+    return in.Name() != out.Name() && details::NodeCanReused(in) &&
+           details::NodeCanReused(out) &&
+           details::NodeSizeInBytes(out) <= details::NodeSizeInBytes(in);
   }
 };
 
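After this change, TryInplaceInputOutput delegates to the shared helpers: an output may reuse the input's buffer only when the two vars have different names, both pass NodeCanReused, and the output needs no more bytes than the input. A simplified, self-contained model of that predicate is sketched below; the Var struct and its fields are hypothetical stand-ins for VarDesc, not Paddle's API.

#include <cstddef>
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-in for the fields the real check reads from VarDesc.
struct Var {
  std::string name;
  std::vector<int64_t> shape;
  size_t bytes;       // e.g. computed as in the size sketch earlier
  bool persistable;
  bool lod_tensor;
};

// Mirrors NodeCanReused(const VarDesc&): persistable vars, non-LoD tensors,
// shapeless vars, and special @...@ names are never reused.
bool CanReuse(const Var& v) {
  if (v.persistable || !v.lod_tensor || v.shape.empty()) return false;
  if (!v.name.empty() && v.name.front() == '@' && v.name.back() == '@')
    return false;
  return true;
}

// Mirrors TryInplaceInputOutput: distinct names, both reusable, and the output
// must not need more memory than the input it would share.
bool TryInplace(const Var& in, const Var& out) {
  return in.name != out.name && CanReuse(in) && CanReuse(out) &&
         out.bytes <= in.bytes;
}

int main() {
  Var x{"x", {32, 128}, 32 * 128 * 4, false, true};
  Var y{"y", {32, 128}, 32 * 128 * 4, false, true};
  std::cout << std::boolalpha << TryInplace(x, y) << "\n";  // true
}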

python/paddle/fluid/compiler.py

Lines changed: 5 additions & 0 deletions
@@ -174,6 +174,11 @@ def _compile_data_parallel(self):
             self._exec_strategy.num_threads = cpu_num * 2
 
         trainers_endpoints = self._program._trainers_endpoints
+
+        # FIXME(dzhwinter): enable_inplace should be after memory_optimize
+        # if turn on python memory optimize, turn off the inplace_pass.
+        self._build_strategy.enable_inplace = False if self._program._is_mem_optimized else True
+
         if self._build_strategy.num_trainers > 1 and trainers_endpoints:
             assert self._build_strategy.num_trainers == len(
                 trainers_endpoints), "num_trainers == len(end_points)"

python/paddle/fluid/framework.py

Lines changed: 7 additions & 6 deletions
@@ -1725,18 +1725,19 @@ def __init__(self):
         self._trainers_endpoints = []
         # the distributed lookup table names
         self._distributed_lookup_table = None
+        # @deprecated(the python memory optimize transpiler is deprecated)
         # whether the program is optimized by memory_optimize_transpiler
-        self.__is_optimized = False
+        self.__is_mem_optimized = False
 
     @property
-    def _is_optimized(self):
+    def _is_mem_optimized(self):
         # if the program is optimized, operator input/outputs
         # maybe same, which conflict with save_inference_model.
-        return self.__is_optimized
+        return self.__is_mem_optimized
 
-    @_is_optimized.setter
-    def _is_optimized(self, target):
-        self.__is_optimized = target
+    @_is_mem_optimized.setter
+    def _is_mem_optimized(self, target):
+        self.__is_mem_optimized = target
 
     @property
     def op_role(self):

python/paddle/fluid/io.py

Lines changed: 1 addition & 1 deletion
@@ -931,7 +931,7 @@ def save_inference_model(dirname,
 
     if main_program is None:
         main_program = default_main_program()
-    if main_program._is_optimized:
+    if main_program._is_mem_optimized:
         warnings.warn(
             "save_inference_model must put before you call memory_optimize. \
             the memory_optimize will modify the original program, \

python/paddle/fluid/parallel_executor.py

Lines changed: 1 addition & 1 deletion
@@ -148,7 +148,7 @@ def __init__(self,
             else framework.default_main_program()
         # FIXME(dzhwinter): enable_inplace should be after memory_optimize
        # if turn on python memory optimize, turn off the inplace_pass.
-        build_strategy.enable_inplace = False if main._is_optimized else True
+        build_strategy.enable_inplace = False if main._is_mem_optimized else True
         scope = scope if scope is not None else executor.global_scope()
 
         if share_vars_from and not isinstance(share_vars_from,
