Skip to content

Commit c7c6eeb

Browse files
authored
Merge pull request #16409 from sneaxiy/feature/advance_gc
Enhance gc to support deleting tensor buffer in advance
2 parents 54a7357 + a0f4fef commit c7c6eeb

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+1083
-381
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ cc_library(lod_tensor SRCS lod_tensor.cc DEPS ddim place tensor framework_proto
6363
cc_test(lod_tensor_test SRCS lod_tensor_test.cc DEPS lod_tensor memory)
6464
nv_test(lod_tensor_gpu_test SRCS lod_tensor_test.cu DEPS lod_tensor)
6565

66-
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory)
66+
cc_library(garbage_collector SRCS garbage_collector.cc DEPS device_context memory gflags glog)
6767

6868
cc_library(reader SRCS reader.cc DEPS lod_tensor ddim)
6969
cc_test(reader_test SRCS reader_test.cc DEPS reader)
@@ -164,6 +164,8 @@ else()
164164
set(NGRAPH_EXE_DEPS)
165165
endif()
166166

167+
cc_library(executor_gc_helper SRCS executor_gc_helper.cc DEPS scope proto_desc operator garbage_collector)
168+
167169
if(WITH_DISTRIBUTE)
168170
cc_library(executor SRCS executor.cc DEPS op_registry device_context scope framework_proto glog
169171
lod_rank_table feed_fetch_method sendrecvop_rpc ${GLOB_DISTRIBUTE_DEPS} graph_to_program_pass variable_helper ${NGRAPH_EXE_DEPS})
@@ -174,7 +176,7 @@ else()
174176
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
175177
endif()
176178

177-
target_link_libraries(executor garbage_collector while_op_helper)
179+
target_link_libraries(executor while_op_helper executor_gc_helper)
178180

179181
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
180182
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor
@@ -194,6 +196,7 @@ cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_con
194196
cc_test(var_type_inference_test SRCS var_type_inference_test.cc DEPS op_registry
195197
proto_desc)
196198
cc_test(inplace_op_inference_test SRCS inplace_op_inference_test.cc DEPS op_registry proto_desc op_info memory_optimize_helper)
199+
197200
cc_library(selected_rows SRCS selected_rows.cc DEPS tensor)
198201
cc_test(selected_rows_test SRCS selected_rows_test.cc DEPS selected_rows)
199202

paddle/fluid/framework/details/eager_deletion_pass.cc

Lines changed: 5 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -22,14 +22,9 @@
2222
#include "paddle/fluid/framework/details/computation_op_handle.h"
2323
#include "paddle/fluid/framework/details/eager_deletion_op_handle.h"
2424
#include "paddle/fluid/framework/details/multi_devices_helper.h"
25+
#include "paddle/fluid/framework/garbage_collector.h"
2526
#include "paddle/fluid/framework/ir/graph_helper.h"
2627

27-
DEFINE_double(memory_fraction_of_eager_deletion, 1.0,
28-
"Fraction of eager deletion. If less than 1.0, all variables in "
29-
"the program would be sorted according to its memory size, and "
30-
"only the FLAGS_memory_fraction_of_eager_deletion of the largest "
31-
"variables would be deleted.");
32-
3328
namespace paddle {
3429
namespace framework {
3530
namespace details {
@@ -206,8 +201,9 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
206201
}
207202
}
208203

209-
op_vars_map = ShrinkGCVars(op_vars_map, vars, places,
210-
FLAGS_memory_fraction_of_eager_deletion);
204+
double memory_fraction = framework::GetEagerDeletionMemoryFraction();
205+
206+
op_vars_map = ShrinkGCVars(op_vars_map, vars, places, memory_fraction);
211207

212208
for (auto &pair : op_vars_map) {
213209
auto *op = pair.first;
@@ -239,8 +235,7 @@ std::unique_ptr<ir::Graph> EagerDeletionPass::ApplyImpl(
239235
eager_deletion_op->AddOutput(dummy_leaf);
240236
}
241237

242-
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = "
243-
<< FLAGS_memory_fraction_of_eager_deletion;
238+
VLOG(10) << "FLAGS_memory_fraction_of_eager_deletion = " << memory_fraction;
244239
VLOG(10) << "Create " << op_vars_map.size() << " EagerDeletionOpHandle(s)";
245240

246241
auto while_op_eager_deletion_pass =

paddle/fluid/framework/details/early_delete_op_handle.h

Lines changed: 0 additions & 140 deletions
This file was deleted.

paddle/fluid/framework/details/op_registry.h

Lines changed: 88 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License. */
2121
#include <vector>
2222
#include "paddle/fluid/framework/grad_op_desc_maker.h"
2323
#include "paddle/fluid/framework/inplace_op_inference.h"
24+
#include "paddle/fluid/framework/no_need_buffer_vars_inference.h"
2425
#include "paddle/fluid/framework/op_info.h"
2526
#include "paddle/fluid/framework/op_proto_maker.h"
2627
#include "paddle/fluid/framework/operator.h"
@@ -36,27 +37,86 @@ enum OpInfoFillType {
3637
kGradOpDescMaker = 2,
3738
kVarTypeInference = 3,
3839
kShapeInference = 4,
39-
kInplaceOpInference = 5
40+
kInplaceOpInference = 5,
41+
kNoNeedBufferVarsInference = 6,
42+
kUnknown = -1
4043
};
4144

45+
namespace internal {
46+
// Compile-time pair that binds a type `T` to an OpInfoFillType tag `kType`.
// `Type` exposes the bound type and `kFillType` exposes the bound tag.
template <typename T, OpInfoFillType kType>
47+
struct TypePair {
48+
using Type = T;
49+
static constexpr OpInfoFillType kFillType = kType;
50+
};
51+
52+
// Ordered list of (base class, fill type) pairs. OpInfoFillTypeGetter scans
// this tuple starting at index 0 and reports the tag of the first entry whose
// base class is a base of the queried type, so the order of entries matters
// for a type that derives from more than one listed class.
using OpRegistryClasses = std::tuple< // NOLINT
53+
TypePair<OperatorBase, kOperator>, // NOLINT
54+
TypePair<OpProtoAndCheckerMaker, kOpProtoAndCheckerMaker>, // NOLINT
55+
TypePair<GradOpDescMakerBase, kGradOpDescMaker>, // NOLINT
56+
TypePair<VarTypeInference, kVarTypeInference>, // NOLINT
57+
TypePair<InferShapeBase, kShapeInference>, // NOLINT
58+
TypePair<InplaceOpInference, kInplaceOpInference>, // NOLINT
59+
TypePair<NoNeedBufferVarsInference, kNoNeedBufferVarsInference> // NOLINT
60+
>;
61+
62+
// Number of entries in OpRegistryClasses; used as the exclusive upper bound
// for valid tuple positions.
static constexpr int kOpRegistryClassNumber =
63+
std::tuple_size<OpRegistryClasses>::value;
64+
65+
// Primary template, used when kIsBounded is true (kPos is a valid index into
// OpRegistryClasses). kValue is true when the class stored at position kPos
// is a base of T according to std::is_base_of.
template <typename T, int kPos, bool kIsBounded /* = true*/>
66+
struct IsMatchedBaseTypeImpl {
67+
using PairType = typename std::tuple_element<kPos, OpRegistryClasses>::type;
68+
static constexpr bool kValue =
69+
std::is_base_of<typename PairType::Type, T>::value;
70+
};
71+
72+
// Specialization for kIsBounded == false (kPos is not a valid index): never
// matches. It does not touch std::tuple_element, so an out-of-range kPos does
// not instantiate it with an invalid index.
template <typename T, int kPos>
73+
struct IsMatchedBaseTypeImpl<T, kPos, false> {
74+
static constexpr bool kValue = false;
75+
};
76+
77+
// Returns true when the class registered at position kPos of
// OpRegistryClasses is a base of T. Positions outside
// [0, kOpRegistryClassNumber) yield false.
template <typename T, int kPos>
78+
static inline constexpr bool IsMatchedBaseType() {
79+
return IsMatchedBaseTypeImpl<
80+
T, kPos, (kPos >= 0 && kPos < kOpRegistryClassNumber)>::kValue;
81+
}
82+
83+
// Compile-time linear search for the fill type of T over positions
// [kStart, kEnd). kIsEnd tells whether kStart has reached kEnd; kIsMatched
// tells whether the class at kStart is a base of T. The primary template is
// empty; `kType` is provided by the partial specializations on
// (kIsEnd, kIsMatched).
template <typename T, int kStart, int kEnd, bool kIsEnd, bool kIsMatched>
84+
struct OpInfoFillTypeGetterImpl {};
85+
86+
// This case should not happen: kIsEnd means kStart has reached kEnd, and when
// kEnd is kOpRegistryClassNumber that position is out of range, for which
// IsMatchedBaseType() is always false. No kType is defined here.
87+
template <typename T, int kStart, int kEnd>
88+
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, true> {};
89+
90+
// End of the search range reached without a match: the fill type is kUnknown.
template <typename T, int kStart, int kEnd>
91+
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, true, false> {
92+
static constexpr OpInfoFillType kType = kUnknown;
93+
};
94+
95+
// No match at kStart and the range is not exhausted: take kType from the
// search at position kStart + 1. The next step's kIsEnd flag becomes true
// once kStart + 1 equals kEnd, and its kIsMatched flag is evaluated for
// position kStart + 1.
template <typename T, int kStart, int kEnd>
96+
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, false> {
97+
static constexpr OpInfoFillType kType =
98+
OpInfoFillTypeGetterImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd,
99+
IsMatchedBaseType<T, kStart + 1>()>::kType;
100+
};
101+
102+
// Match at kStart: the search stops here and kType is the tag paired with the
// class stored at position kStart of OpRegistryClasses.
template <typename T, int kStart, int kEnd>
103+
struct OpInfoFillTypeGetterImpl<T, kStart, kEnd, false, true> {
104+
using PairType = typename std::tuple_element<kStart, OpRegistryClasses>::type;
105+
static constexpr OpInfoFillType kType = PairType::kFillType;
106+
};
107+
108+
// Entry point of the search: OpInfoFillTypeGetter<T>::kType is the tag of the
// first entry in OpRegistryClasses whose class is a base of T, or kUnknown
// when no entry matches. The search starts at position 0 and is bounded by
// kOpRegistryClassNumber; an empty tuple is treated as already at the end.
template <typename T>
109+
using OpInfoFillTypeGetter =
110+
OpInfoFillTypeGetterImpl<T, 0, kOpRegistryClassNumber,
111+
kOpRegistryClassNumber == 0,
112+
IsMatchedBaseType<T, 0>()>;
113+
114+
} // namespace internal
115+
42116
template <typename T>
43117
struct OpInfoFillTypeID {
44118
static constexpr OpInfoFillType ID() {
45-
return std::is_base_of<OperatorBase, T>::value
46-
? kOperator
47-
: (std::is_base_of<OpProtoAndCheckerMaker, T>::value
48-
? kOpProtoAndCheckerMaker
49-
: (std::is_base_of<GradOpDescMakerBase, T>::value
50-
? kGradOpDescMaker
51-
: (std::is_base_of<VarTypeInference, T>::value
52-
? kVarTypeInference
53-
: (std::is_base_of<InferShapeBase, T>::value
54-
? kShapeInference
55-
: (std::is_base_of<
56-
InplaceOpInference, T>::value
57-
? kInplaceOpInference
58-
: static_cast<OpInfoFillType>(
59-
-1))))));
119+
return internal::OpInfoFillTypeGetter<T>::kType;
60120
}
61121
};
62122

@@ -156,6 +216,18 @@ struct OpInfoFiller<T, kInplaceOpInference> {
156216
}
157217
};
158218

219+
// OpInfoFiller specialization for the kNoNeedBufferVarsInference fill type.
// Sets info->infer_no_need_buffer_vars_ to a capture-less lambda that, given
// an op's inputs, outputs and attrs, constructs a T from them and returns the
// result of invoking it. The op_type argument is not used by this filler.
template <typename T>
220+
struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
221+
void operator()(const char* op_type, OpInfo* info) const {
222+
info->infer_no_need_buffer_vars_ = [](const VariableNameMap& inputs,
223+
const VariableNameMap& outputs,
224+
const AttributeMap& attrs) {
225+
T infer(inputs, outputs, attrs);
226+
return infer();
227+
};
228+
}
229+
};
230+
159231
} // namespace details
160232

161233
} // namespace framework

0 commit comments

Comments
 (0)