
Commit 4c8254e

revert some loop op revision
test=develop
1 parent 16f0994 commit 4c8254e

11 files changed: +380 additions, −40 deletions
paddle/fluid/operators/controlflow/CMakeLists.txt

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 include(operators)
 register_operators(DEPS naive_executor)
-cc_library(loop_op_helper SRCS loop_op_helper.cc DEPS operator)
+cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
 
 file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")
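The change above only renames the helper library from loop_op_helper back to while_op_helper. For reference, a downstream target would pick the helper up through DEPS in the same way; a minimal hypothetical sketch (target and source names are illustrative, cc_library is the same Paddle CMake wrapper used above):

# Hypothetical consumer target; only the DEPS wiring is the point here.
cc_library(my_while_based_op SRCS my_while_based_op.cc DEPS while_op_helper)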

paddle/fluid/operators/controlflow/while_op.cc

Lines changed: 14 additions & 7 deletions
@@ -18,21 +18,28 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/var_type.h"
-#include "paddle/fluid/operators/controlflow/loop_op_helper.h"
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
 #include "paddle/fluid/operators/detail/safe_ref.h"
 
 namespace paddle {
 namespace operators {
 
-static constexpr char kCondition[] = "Condition";
-static constexpr char kStepScopes[] = "StepScopes";
-static constexpr char kX[] = "X";
-static constexpr char kXGRAD[] = "X@GRAD";
-static constexpr char kOutputs[] = "Out";
-
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
 
+namespace {  // NOLINT
+static std::string GetSkipEagerDeletionVarsDebugString(
+    const std::vector<std::string> &vars) {
+  std::string str = "Skip " + std::to_string(vars.size()) +
+                    " var(s) in eager deletion mode: ";
+  for (auto &var : vars) {
+    str.append(var);
+    str.push_back(' ');
+  }
+  return str;
+}
+}  // NOLINT
+
 class WhileOp : public framework::OperatorBase {
  public:
   WhileOp(const std::string &type, const framework::VariableNameMap &inputs,
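The added anonymous-namespace helper only formats a log message listing the variables exempted from eager deletion. A standalone sketch of the same function with a small test harness (the main() is mine, not part of the commit):

#include <iostream>
#include <string>
#include <vector>

// Same helper as in the diff above, lifted out of Paddle for illustration.
static std::string GetSkipEagerDeletionVarsDebugString(
    const std::vector<std::string> &vars) {
  std::string str = "Skip " + std::to_string(vars.size()) +
                    " var(s) in eager deletion mode: ";
  for (auto &var : vars) {
    str.append(var);
    str.push_back(' ');
  }
  return str;
}

int main() {
  // Prints: Skip 2 var(s) in eager deletion mode: x@GRAD hidden_0
  std::cout << GetSkipEagerDeletionVarsDebugString({"x@GRAD", "hidden_0"})
            << std::endl;
  return 0;
}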
paddle/fluid/operators/controlflow/while_op_helper.cc

Lines changed: 291 additions & 0 deletions

@@ -0,0 +1,291 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#include <string>
+#include <unordered_set>
+#include <utility>
+#include "paddle/fluid/framework/program_desc.h"
+
+namespace paddle {
+namespace operators {
+
+// OpVariant is a wrapper class of OpDesc and OperatorBase
+// So that API would be the same.
+class OpVariant {
+  struct InputsVisitor
+      : public boost::static_visitor<const framework::VariableNameMap *> {
+    template <typename OpType>
+    const framework::VariableNameMap *operator()(const OpType *op) const {
+      return &(op->Inputs());
+    }
+  };
+
+  struct OutputsVisitor
+      : public boost::static_visitor<const framework::VariableNameMap *> {
+    template <typename OpType>
+    const framework::VariableNameMap *operator()(const OpType *op) const {
+      return &(op->Outputs());
+    }
+  };
+
+  struct AttributeMapVisitor
+      : public boost::static_visitor<const framework::AttributeMap *> {
+    const framework::AttributeMap *operator()(
+        const framework::OpDesc *op) const {
+      return &(op->GetAttrMap());
+    }
+
+    const framework::AttributeMap *operator()(
+        const framework::OperatorBase *op) const {
+      return &(op->Attrs());
+    }
+  };
+
+  struct RawPointerVisitor : public boost::static_visitor<const void *> {
+    template <typename OpType>
+    const void *operator()(const OpType *op) const {
+      return op;
+    }
+  };
+
+ public:
+  OpVariant(const framework::OperatorBase *op) : op_(op) {}  // NOLINT
+
+  OpVariant(const framework::OpDesc *op) : op_(op) {}  // NOLINT
+
+  const framework::VariableNameMap &Inputs() const {
+    return *boost::apply_visitor(InputsVisitor(), op_);
+  }
+
+  const framework::VariableNameMap &Outputs() const {
+    return *boost::apply_visitor(OutputsVisitor(), op_);
+  }
+
+  const framework::AttributeMap &Attrs() const {
+    return *boost::apply_visitor(AttributeMapVisitor(), op_);
+  }
+
+  template <typename AttrType>
+  const AttrType &Attr(const std::string &name) const {
+    auto &attrs = Attrs();
+    auto it = attrs.find(name);
+    PADDLE_ENFORCE(it != attrs.end(), "Cannot find attribute %s", name);
+    return boost::get<AttrType>(it->second);
+  }
+
+  bool operator==(const OpVariant &other) const {
+    return RawPointer() == other.RawPointer();
+  }
+
+  const void *RawPointer() const {
+    return boost::apply_visitor(RawPointerVisitor(), op_);
+  }
+
+  int which() const { return static_cast<int>(op_.which()); }
+
+  struct Hasher {
+    size_t operator()(const OpVariant &op) const {
+      return reinterpret_cast<size_t>(op.RawPointer());
+    }
+  };
+
+ private:
+  const boost::variant<const framework::OperatorBase *,
+                       const framework::OpDesc *>
+      op_;
+};
+
+static std::string GetDebugString(const std::vector<std::string> &names) {
+  if (names.empty()) return "";
+  std::string ret = names[0];
+  for (size_t i = 1; i < names.size(); ++i) {
+    ret += (" " + names[i]);
+  }
+  return ret;
+}
+
+// Set skip variables of while_op and while_grad_op
+// These variables should be skipped when eager deletion enables.
+// It is because:
+// 1. while_grad_op needs some variables defined in while_op.
+// 2. while_grad_op needs variables from the previous time step.
+static void SetSkipVars(const OpVariant &op, std::vector<std::string> attr) {
+  auto &attrs = const_cast<framework::AttributeMap &>(op.Attrs());
+  VLOG(2) << "Prepare to skip " << attr.size()
+          << " var(s): " << GetDebugString(attr);
+  attrs[kSkipEagerDeletionVars] = std::move(attr);
+}
+
+// Check whether the forward while_op and while_grad_op match
+// The program may have many while_ops.
+static bool IsMatchedWhileOpAndWhileGradOp(const OpVariant &fwd_op,
+                                           const OpVariant &grad_op) {
+  return fwd_op.Inputs().at(kX) == grad_op.Inputs().at(kX) &&
+         fwd_op.Outputs().at(kOutputs) == grad_op.Inputs().at(kOutputs);
+}
+
+// Test whether the variable is skippable in forward while_op
+// The variable is skippable in while_op when the variable used in while_grad
+// is not from grad_block.
+static bool IsSkippableVar(const std::string &name,
+                           framework::BlockDesc *grad_block) {
+  return name != framework::kEmptyVarName && !grad_block->HasVar(name);
+}
+
+static void ModifyWhileOpAndWhileGradOpAttr(const OpVariant &fwd_op,
+                                            const OpVariant &bwd_op) {
+  auto *grad_block = bwd_op.Attr<framework::BlockDesc *>(kStepBlock);
+
+  // Find all skippable variables in forward while_op
+  std::unordered_set<std::string> forward_skip_vars;
+  for (auto *op_desc : grad_block->AllOps()) {
+    for (auto &in_arg_name : op_desc->InputArgumentNames()) {
+      if (IsSkippableVar(in_arg_name, grad_block)) {
+        forward_skip_vars.insert(in_arg_name);
+      }
+    }
+
+    for (auto &out_arg_name : op_desc->OutputArgumentNames()) {
+      if (IsSkippableVar(out_arg_name, grad_block)) {
+        forward_skip_vars.insert(out_arg_name);
+      }
+    }
+  }
+
+  SetSkipVars(fwd_op, std::vector<std::string>(forward_skip_vars.begin(),
+                                               forward_skip_vars.end()));
+
+  // Find all skippable variables in while_grad_op
+  // The skipped variables are those which would be used across time steps.
+  auto &fwd_input = fwd_op.Inputs().at(kX);
+  auto &in_grads = bwd_op.Outputs().at(framework::GradVarName(kX));
+  PADDLE_ENFORCE_EQ(
+      fwd_input.size(), in_grads.size(),
+      "Backward input gradient number does not match forward input number.");
+
+  std::unordered_set<std::string> backward_skip_vars;
+  for (size_t i = 0; i < in_grads.size(); ++i) {
+    if (in_grads[i] == framework::kEmptyVarName) {
+      continue;
+    }
+    backward_skip_vars.insert(in_grads[i]);
+    backward_skip_vars.insert(framework::GradVarName(fwd_input[i]));
+  }
+
+  SetSkipVars(bwd_op, std::vector<std::string>(backward_skip_vars.begin(),
+                                               backward_skip_vars.end()));
+}
+
+// Find all while_ops and while_grad_ops in the graph or program
+// The while_grad_op and while_op may located in different blocks
+// So we should traverse all blocks in the program and find them out.
+static void FindAllWhileAndWhileGradOp(std::vector<OpVariant> *while_ops,
+                                       std::vector<OpVariant> *while_grad_ops) {
+  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size());
+
+  if (while_ops->empty()) return;
+
+  const auto *program =
+      while_ops->front().Attr<framework::BlockDesc *>(kStepBlock)->Program();
+  for (size_t i = 1; i < program->Size(); ++i) {
+    auto &block = program->Block(i);
+    for (size_t j = 0; j < block.OpSize(); ++j) {
+      auto *op = block.Op(j);
+      if (op->Type() == "while") {
+        while_ops->emplace_back(op);
+      } else if (op->Type() == "while_grad") {
+        while_grad_ops->emplace_back(op);
+      }
+    }
+  }
+
+  PADDLE_ENFORCE_GE(while_ops->size(), while_grad_ops->size(),
+                    "There are extra while_grad ops in the graph or program");
+}
+
+static void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(
+    std::vector<OpVariant> *while_ops, std::vector<OpVariant> *while_grad_ops) {
+  FindAllWhileAndWhileGradOp(while_ops, while_grad_ops);
+
+  VLOG(2) << "Found while op num: " << while_ops->size()
+          << ", while grad op num: " << while_grad_ops->size();
+
+  if (while_grad_ops->empty()) {
+    return;
+  }
+
+  std::unordered_set<OpVariant, OpVariant::Hasher> while_op_set(
+      while_ops->begin(), while_ops->end());
+
+  for (auto &bwd_op : *while_grad_ops) {
+    const OpVariant *matched_fwd_op = nullptr;
+    for (auto &fwd_op : while_op_set) {
+      if (IsMatchedWhileOpAndWhileGradOp(fwd_op, bwd_op)) {
+        PADDLE_ENFORCE(matched_fwd_op == nullptr,
+                       "Found multiple matched while ops");
+        matched_fwd_op = &fwd_op;
+      }
+    }
+    PADDLE_ENFORCE_NOT_NULL(matched_fwd_op,
+                            "Cannot find matched forward while op.");
+    ModifyWhileOpAndWhileGradOpAttr(*matched_fwd_op, bwd_op);
+    while_op_set.erase(*matched_fwd_op);
+  }
+}
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops) {
+  // If block_id is not 0, returns
+  // This is because all while_ops and while_grad_ops in the whole program
+  // would be processed when block_id is 0 (i.e. when Executor::Run() or
+  // ParallelExecutor constructs).
+
+  // What's more, all while_ops and while_grad_ops must be processed when
+  // block_id is zero. If not, while_op may run first and erase variables
+  // used in while_grad_op, and in this moment, while_grad_ops may be not
+  // constructed yet.
+  if (block_id != 0) return;
+
+  std::vector<OpVariant> fwd_ops, bwd_ops;
+  for (auto &op : all_ops) {
+    if (op->Type() == "while") {
+      fwd_ops.emplace_back(op.get());
+    } else if (op->Type() == "while_grad") {
+      bwd_ops.emplace_back(op.get());
+    }
+  }
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+}
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const std::vector<framework::OperatorBase *> &while_ops,
+    const std::vector<framework::OperatorBase *> &while_grad_ops) {
+  std::vector<OpVariant> fwd_ops, bwd_ops;
+  fwd_ops.reserve(while_ops.size());
+  for (auto *op : while_ops) {
+    fwd_ops.emplace_back(op);
+  }
+
+  bwd_ops.reserve(while_grad_ops.size());
+  for (auto *op : while_grad_ops) {
+    bwd_ops.emplace_back(op);
+  }
+
+  PrepareSafeEagerDeletionOnWhileOpAndWhileGradOpImpl(&fwd_ops, &bwd_ops);
+}
+
+}  // namespace operators
+}  // namespace paddle
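OpVariant above unifies framework::OpDesc and framework::OperatorBase behind one interface via boost::variant and boost::static_visitor. A minimal self-contained sketch of that pattern with toy types (nothing below is Paddle code; the types and names are invented for illustration):

#include <boost/variant.hpp>
#include <iostream>
#include <string>

// Toy stand-ins for framework::OpDesc / framework::OperatorBase, which
// expose the same information through differently named methods.
struct CompiledOp {
  std::string name() const { return "compiled"; }
};
struct RuntimeOp {
  std::string label() const { return "runtime"; }
};

// A static_visitor unifies the two APIs behind one call, just as
// OpVariant's InputsVisitor / OutputsVisitor do in the diff above.
struct NameVisitor : public boost::static_visitor<std::string> {
  std::string operator()(const CompiledOp *op) const { return op->name(); }
  std::string operator()(const RuntimeOp *op) const { return op->label(); }
};

int main() {
  CompiledOp c;
  RuntimeOp r;
  boost::variant<const CompiledOp *, const RuntimeOp *> v1 = &c;
  boost::variant<const CompiledOp *, const RuntimeOp *> v2 = &r;
  std::cout << boost::apply_visitor(NameVisitor(), v1) << "\n";  // compiled
  std::cout << boost::apply_visitor(NameVisitor(), v2) << "\n";  // runtime
  return 0;
}

The RawPointer()/Hasher pair in the real class serves the same unifying purpose for unordered_set membership: two OpVariants compare equal exactly when they wrap the same underlying object.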
paddle/fluid/operators/controlflow/while_op_helper.h

Lines changed: 43 additions & 0 deletions

@@ -0,0 +1,43 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include <vector>
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/platform/variant.h"
+
+namespace paddle {
+namespace operators {
+
+static constexpr char kStepBlock[] = "sub_block";
+static constexpr char kCondition[] = "Condition";
+static constexpr char kStepScopes[] = "StepScopes";
+static constexpr char kX[] = "X";
+static constexpr char kXGRAD[] = "X@GRAD";
+static constexpr char kOutputs[] = "Out";
+static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    int block_id,
+    const std::vector<std::unique_ptr<framework::OperatorBase>> &all_ops);
+
+void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
+    const std::vector<framework::OperatorBase *> &while_ops,
+    const std::vector<framework::OperatorBase *> &while_grad_ops);
+
+}  // namespace operators
+}  // namespace paddle
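A hedged sketch of the intended call pattern for the first overload, based only on the declarations above (the real call sites live in the executor code, which this commit does not show; the function name below is hypothetical):

#include "paddle/fluid/operators/controlflow/while_op_helper.h"

// Hypothetical executor-side function; names are illustrative only.
void RunBlockSketch(
    int block_id,
    const std::vector<std::unique_ptr<paddle::framework::OperatorBase>>
        &all_ops) {
  // Must happen before any while op runs: it stamps each while / while_grad
  // op with the skip_eager_deletion_vars attribute so that a forward while
  // op cannot free tensors its grad op still needs.
  paddle::operators::PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
      block_id, all_ops);
  // ... then execute all_ops in order.
}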

paddle/fluid/operators/interpolate_op.cc

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,7 @@
    limitations under the License. */
 
 #include "paddle/fluid/operators/interpolate_op.h"
+#include <memory>
 #include <string>
 #include <vector>
 #include "paddle/fluid/framework/op_registry.h"

@@ -209,6 +210,9 @@ class InterpolateGradDescMaker : public framework::SingleGradOpDescMaker {
     std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
     op->SetType(ForwardOp().Type() + "_grad");
     op->SetInput("X", Input("X"));
+    if (ForwardOp().Inputs().count("OutSize") > 0) {
+      op->SetInput("OutSize", Input("OutSize"));
+    }
     op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
     op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
     op->SetAttrMap(Attrs());
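Why the new guard is needed (my reading of the change, not stated in the commit message): OutSize is an optional input of the interpolate ops, so a forward op may have been created without it. Previously the grad-desc maker never forwarded OutSize at all, losing the target shape in the backward pass when it was present; forwarding it unconditionally would instead reference a missing input when it is absent. The count() check forwards OutSize to the generated grad op exactly when the forward op has it.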
