Skip to content

Commit cb74dac

Browse files
authored
[Cherry-pick] Support memory eager deletion on recurrent OP (#19411)
* Support memory eager deletion on recurrent OP (#17710) Test PaddingRNN on V100 GPU device. Test configuration: large model, padding mode (which is the mode using recurrentOp), one GPU. GPU memory (MiB): 6414 (this PR) vs 6837 (without this PR) Speed (steps/s): 10.28 (this PR) vs 9.89 (without this PR) * Fix random test_recurrent_op failure (#18718) The change includes 3 things: 1. Set CPU_NUM to 1 in the tests because the ParallelExecutor will print a warning that CPU_NUM is not set and use the default of 1. 2. Old tests compare two RNNs, a hand-written simple RNN and the same RNN built by Paddle, but initialized RNN weights with numpy random and Paddle random separately. Fixed it by setting weight and bias values explicitly. 3. Also set the numpy random seed in the tests. Now the diff between the two RNNs can be smaller (rtol tightened from 0.1, 0.2 to 0.01) in the tests.
1 parent a7a4b72 commit cb74dac

19 files changed

+2640
-630
lines changed

paddle/fluid/framework/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,7 @@ else()
196196
cc_test(test_naive_executor SRCS naive_executor_test.cc DEPS naive_executor elementwise_add_op)
197197
endif()
198198

199-
target_link_libraries(executor while_op_helper executor_gc_helper)
199+
target_link_libraries(executor while_op_helper executor_gc_helper recurrent_op_helper)
200200

201201
cc_library(parallel_executor SRCS parallel_executor.cc DEPS
202202
threaded_ssa_graph_executor scope_buffered_ssa_graph_executor parallel_ssa_graph_executor async_ssa_graph_executor

paddle/fluid/framework/executor.cc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ limitations under the License. */
3030
#include "paddle/fluid/framework/trainer_factory.h"
3131
#include "paddle/fluid/framework/transfer_scope_cache.h"
3232
#include "paddle/fluid/framework/variable_helper.h"
33+
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
3334
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
3435
#include "paddle/fluid/operators/distributed/distributed.h"
3536
#include "paddle/fluid/platform/place.h"
@@ -410,6 +411,8 @@ void Executor::RunPreparedContext(ExecutorPrepareContext* ctx, Scope* scope,
410411
if (gc && ctx->prog_.Size() > 1) {
411412
operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(ctx->block_id_,
412413
ctx->ops_);
414+
operators::PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(
415+
ctx->block_id_, ctx->ops_);
413416
}
414417
}
415418

paddle/fluid/framework/ir/memory_optimize_pass/CMakeLists.txt

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
cc_library(op_graph_view SRCS op_graph_view.cc DEPS op_handle_base)
22
cc_library(while_op_eager_deletion_pass SRCS while_op_eager_deletion_pass.cc DEPS while_op_helper graph_helper pass computation_op_handle)
3+
cc_library(recurrent_op_eager_deletion_pass SRCS recurrent_op_eager_deletion_pass.cc DEPS recurrent_op_helper graph_helper pass computation_op_handle)
34
cc_library(reference_count_pass_helper SRCS reference_count_pass_helper.cc DEPS garbage_collector computation_op_handle var_handle)
45
cc_library(reference_count_pass SRCS reference_count_pass.cc DEPS computation_op_handle graph graph_helper pass op_graph_view reference_count_pass_helper)
56

@@ -14,5 +15,6 @@ cc_library(inplace_op_pass SRCS inplace_op_pass.cc DEPS memory_optimize_pass op_
1415

1516
cc_test(memory_optimize_helper_test SRCS memory_optimize_helper_test.cc memory_optimize_helper.cc DEPS framework_proto graph graph_helper op_registry)
1617

17-
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass reference_count_pass_helper)
18+
cc_library(eager_deletion_pass SRCS eager_deletion_pass.cc DEPS computation_op_handle
19+
eager_deletion_op_handle graph graph_helper pass while_op_eager_deletion_pass recurrent_op_eager_deletion_pass reference_count_pass_helper)
1820
cc_library(record_skip_memory_opt_vars_pass SRCS record_skip_memory_opt_vars_pass.cc DEPS graph graph_helper)

paddle/fluid/framework/ir/memory_optimize_pass/eager_deletion_pass.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,10 @@ void EagerDeletionPass::ApplyImpl(ir::Graph *graph) const {
266266
auto while_op_eager_deletion_pass =
267267
ir::PassRegistry::Instance().Get("while_op_eager_deletion_pass");
268268
while_op_eager_deletion_pass->Apply(graph);
269+
270+
auto recurrent_op_eager_deletion_pass =
271+
ir::PassRegistry::Instance().Get("recurrent_op_eager_deletion_pass");
272+
recurrent_op_eager_deletion_pass->Apply(graph);
269273
}
270274

271275
} // namespace ir
@@ -279,3 +283,4 @@ REGISTER_PASS(eager_deletion_pass, paddle::framework::ir::EagerDeletionPass)
279283
.RequirePassAttr(paddle::framework::ir::kGarbageCollector);
280284

281285
USE_PASS(while_op_eager_deletion_pass);
286+
USE_PASS(recurrent_op_eager_deletion_pass);
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/framework/ir/memory_optimize_pass/recurrent_op_eager_deletion_pass.h"
16+
17+
#include <unordered_map>
18+
#include <vector>
19+
20+
#include "paddle/fluid/framework/details/computation_op_handle.h"
21+
#include "paddle/fluid/framework/details/multi_devices_helper.h"
22+
#include "paddle/fluid/framework/ir/graph_helper.h"
23+
#include "paddle/fluid/string/string_helper.h"
24+
25+
namespace paddle {
26+
namespace framework {
27+
namespace ir {
28+
29+
using paddle::operators::OpVariant;
30+
using paddle::operators::OpVariantSet;
31+
using paddle::operators::OpAndGradOpPair;
32+
33+
void RecurrentOpEagerDeletionPass::ApplyImpl(Graph *graph) const {
34+
// Find all recurrent_op and recurrent_grad_op in graph
35+
// Note the graph only contains ops and block 0
36+
std::unordered_map<size_t, OpAndGradOpPair> target_ops =
37+
DeviceIdToRecurrentAndRecurrentGradOp(*graph);
38+
39+
for (auto &entry : target_ops) {
40+
// Prepare safe eager deletion on different devices because the garbage
41+
// collection may be different across devices
42+
OpAndGradOpPair &op_pair = entry.second;
43+
PrepareSafeEagerDeletionOnRecurrentOpAndRecurrentGradOp(&op_pair);
44+
}
45+
}
46+
47+
// Returns a std::unordered_map mapping from the device id to recurrent op and
48+
// grad op pair
49+
std::unordered_map<size_t, OpAndGradOpPair>
50+
RecurrentOpEagerDeletionPass::DeviceIdToRecurrentAndRecurrentGradOp(
51+
const Graph &graph) const {
52+
std::unordered_map<size_t, OpAndGradOpPair> ret;
53+
std::vector<details::OpHandleBase *> all_ops =
54+
FilterByNodeWrapper<details::OpHandleBase>(graph);
55+
56+
for (auto *op : all_ops) {
57+
auto compute_op = dynamic_cast<details::ComputationOpHandle *>(op);
58+
if (compute_op == nullptr) continue;
59+
60+
if (compute_op->Name() == "recurrent") {
61+
// GetScopeIdx() returns device/place id
62+
ret[compute_op->GetScopeIdx()].first.emplace(compute_op->GetOp());
63+
} else if (compute_op->Name() == "recurrent_grad") {
64+
// GetScopeIdx() returns device/place id
65+
ret[compute_op->GetScopeIdx()].second.emplace(compute_op->GetOp());
66+
}
67+
}
68+
return ret;
69+
}
70+
71+
} // namespace ir
72+
} // namespace framework
73+
} // namespace paddle
74+
75+
// Registers the pass under the name "recurrent_op_eager_deletion_pass" so it
// can be fetched via ir::PassRegistry::Instance().Get() (as done in
// EagerDeletionPass::ApplyImpl).
REGISTER_PASS(recurrent_op_eager_deletion_pass,
76+
paddle::framework::ir::RecurrentOpEagerDeletionPass);
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <unordered_map>
18+
19+
#include "paddle/fluid/framework/details/computation_op_handle.h"
20+
#include "paddle/fluid/framework/details/multi_devices_helper.h"
21+
#include "paddle/fluid/framework/ir/graph_helper.h"
22+
#include "paddle/fluid/operators/controlflow/op_variant.h"
23+
#include "paddle/fluid/operators/controlflow/recurrent_op_helper.h"
24+
25+
namespace paddle {
26+
namespace framework {
27+
namespace ir {
28+
29+
// Pass class set skip eager deletion vars for recurrent ops
30+
class RecurrentOpEagerDeletionPass : public Pass {
31+
protected:
32+
void ApplyImpl(Graph *graph) const override;
33+
34+
private:
35+
// Returns a std::unordered_map mapping from the device id to recurrent op and
36+
// grad op pair
37+
std::unordered_map<size_t, paddle::operators::OpAndGradOpPair>
38+
DeviceIdToRecurrentAndRecurrentGradOp(const Graph &graph) const;
39+
};
40+
41+
} // namespace ir
42+
} // namespace framework
43+
} // namespace paddle
Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
include(operators)
22
register_operators(DEPS naive_executor)
3-
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
3+
cc_library(op_variant SRCS op_variant.cc DEPS operator proto_desc)
4+
cc_library(recurrent_op_helper SRCS recurrent_op_helper.cc DEPS operator op_variant recurrent_op)
5+
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator op_variant)
46

57
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/fluid/operators/controlflow/op_variant.h"
16+
17+
namespace paddle {
18+
namespace operators {
19+
20+
struct InputsVisitor
21+
: public boost::static_visitor<const framework::VariableNameMap *> {
22+
template <typename OpType>
23+
const framework::VariableNameMap *operator()(const OpType *op) const {
24+
return &(op->Inputs());
25+
}
26+
};
27+
28+
struct OutputsVisitor
29+
: public boost::static_visitor<const framework::VariableNameMap *> {
30+
template <typename OpType>
31+
const framework::VariableNameMap *operator()(const OpType *op) const {
32+
return &(op->Outputs());
33+
}
34+
};
35+
36+
struct AttributeMapVisitor
37+
: public boost::static_visitor<const framework::AttributeMap *> {
38+
const framework::AttributeMap *operator()(const framework::OpDesc *op) const {
39+
return &(op->GetAttrMap());
40+
}
41+
42+
const framework::AttributeMap *operator()(
43+
const framework::OperatorBase *op) const {
44+
return &(op->Attrs());
45+
}
46+
};
47+
48+
struct RawPointerVisitor : public boost::static_visitor<const void *> {
49+
template <typename OpType>
50+
const void *operator()(const OpType *op) const {
51+
return op;
52+
}
53+
};
54+
55+
const framework::VariableNameMap &OpVariant::Inputs() const {
56+
return *boost::apply_visitor(InputsVisitor(), op_);
57+
}
58+
59+
const framework::VariableNameMap &OpVariant::Outputs() const {
60+
return *boost::apply_visitor(OutputsVisitor(), op_);
61+
}
62+
63+
const framework::AttributeMap &OpVariant::Attrs() const {
64+
return *boost::apply_visitor(AttributeMapVisitor(), op_);
65+
}
66+
67+
const void *OpVariant::RawPointer() const {
68+
return boost::apply_visitor(RawPointerVisitor(), op_);
69+
}
70+
71+
} // namespace operators
72+
} // namespace paddle
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#pragma once
16+
17+
#include <string>
18+
19+
#include "paddle/fluid/framework/operator.h"
20+
#include "paddle/fluid/framework/program_desc.h"
21+
#include "paddle/fluid/platform/variant.h"
22+
23+
namespace paddle {
24+
namespace operators {
25+
26+
// OpVariant is a wrapper class of OpDesc and OperatorBase pointer
27+
// So that API would be the same.
28+
class OpVariant {
29+
public:
30+
OpVariant(const framework::OperatorBase *op) : op_(op) {} // NOLINT
31+
32+
OpVariant(const framework::OpDesc *op) : op_(op) {} // NOLINT
33+
34+
const framework::VariableNameMap &Inputs() const;
35+
36+
const framework::VariableNameMap &Outputs() const;
37+
38+
const framework::AttributeMap &Attrs() const;
39+
40+
const void *RawPointer() const;
41+
42+
template <typename AttrType>
43+
const AttrType &Attr(const std::string &name) const {
44+
auto &attrs = Attrs();
45+
auto it = attrs.find(name);
46+
PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
47+
return boost::get<AttrType>(it->second);
48+
}
49+
50+
bool operator==(const OpVariant &other) const {
51+
return RawPointer() == other.RawPointer();
52+
}
53+
54+
int which() const { return static_cast<int>(op_.which()); }
55+
56+
struct Hasher {
57+
size_t operator()(const OpVariant &op) const {
58+
return reinterpret_cast<size_t>(op.RawPointer());
59+
}
60+
};
61+
62+
private:
63+
const boost::variant<const framework::OperatorBase *,
64+
const framework::OpDesc *>
65+
op_;
66+
};
67+
68+
} // namespace operators
69+
} // namespace paddle

0 commit comments

Comments
 (0)