Skip to content

Commit 63651c1

Browse files
committed
fix grad desc maker
test=develop
1 parent a0f4fef commit 63651c1

28 files changed

+473
-426
lines changed

paddle/fluid/framework/details/reference_count_pass.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -335,6 +335,7 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
335335
var_name);
336336
ref_cnts[i].emplace(var_name, result.size());
337337
last_live_ops_of_vars[i].emplace(var_name, std::move(result));
338+
break;
338339
}
339340

340341
// Seldomly, all preceding trying failed.

paddle/fluid/operators/bpr_loss_op.cc

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
1313
limitations under the License. */
1414

1515
#include "paddle/fluid/operators/bpr_loss_op.h"
16+
#include <memory>
1617

1718
namespace paddle {
1819
namespace operators {
@@ -127,14 +128,31 @@ neural networks>(https://arxiv.org/abs/1511.06939)
127128
)DOC");
128129
}
129130
};
131+
132+
class BprLossGradDescMaker : public framework::SingleGradOpDescMaker {
133+
public:
134+
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
135+
136+
protected:
137+
std::unique_ptr<framework::OpDesc> Apply() const override {
138+
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
139+
op->SetType("bpr_loss_grad");
140+
op->SetInput("X", Input("X"));
141+
op->SetInput("Label", Input("Label"));
142+
op->SetInput(framework::GradVarName("Y"), OutputGrad("Y"));
143+
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
144+
op->SetAttrMap(Attrs());
145+
return op;
146+
}
147+
};
130148
} // namespace operators
131149
} // namespace paddle
132150

133151
namespace ops = paddle::operators;
134152
using CPUCtx = paddle::platform::CPUDeviceContext;
135153

136154
REGISTER_OPERATOR(bpr_loss, ops::BprLossOp, ops::BprLossOpMaker,
137-
paddle::framework::DefaultGradOpDescMaker<true>);
155+
ops::BprLossGradDescMaker);
138156
REGISTER_OPERATOR(bpr_loss_grad, ops::BprLossGradientOp);
139157
REGISTER_OP_CPU_KERNEL(bpr_loss, ops::BprLossOpKernel<CPUCtx, float>,
140158
ops::BprLossOpKernel<CPUCtx, double>);
Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
include(operators)
22
register_operators(DEPS naive_executor)
3-
cc_library(while_op_helper SRCS while_op_helper.cc DEPS operator)
3+
cc_library(loop_op_helper SRCS loop_op_helper.cc DEPS operator)
44

55
file(APPEND ${pybind_file} "USE_OP(less_than);\nUSE_OP(logical_and);\nUSE_NO_KERNEL_OP(read_from_array);\n")

paddle/fluid/operators/controlflow/while_op.cc

Lines changed: 7 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -18,28 +18,21 @@
1818
#include "paddle/fluid/framework/op_registry.h"
1919
#include "paddle/fluid/framework/operator.h"
2020
#include "paddle/fluid/framework/var_type.h"
21-
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
21+
#include "paddle/fluid/operators/controlflow/loop_op_helper.h"
2222
#include "paddle/fluid/operators/detail/safe_ref.h"
2323

2424
namespace paddle {
2525
namespace operators {
2626

27+
static constexpr char kCondition[] = "Condition";
28+
static constexpr char kStepScopes[] = "StepScopes";
29+
static constexpr char kX[] = "X";
30+
static constexpr char kXGRAD[] = "X@GRAD";
31+
static constexpr char kOutputs[] = "Out";
32+
2733
using StepScopeVar = std::vector<framework::Scope *>;
2834
using LoDTensor = framework::LoDTensor;
2935

30-
namespace { // NOLINT
31-
static std::string GetSkipEagerDeletionVarsDebugString(
32-
const std::vector<std::string> &vars) {
33-
std::string str = "Skip " + std::to_string(vars.size()) +
34-
" var(s) in eager deletion mode: ";
35-
for (auto &var : vars) {
36-
str.append(var);
37-
str.push_back(' ');
38-
}
39-
return str;
40-
}
41-
} // NOLINT
42-
4336
class WhileOp : public framework::OperatorBase {
4437
public:
4538
WhileOp(const std::string &type, const framework::VariableNameMap &inputs,

paddle/fluid/operators/controlflow/while_op_helper.cc

Lines changed: 0 additions & 291 deletions
This file was deleted.

0 commit comments

Comments
 (0)