Commit 3a76062

support testing when training and handle dropout and batch_norm operator in testing mode (#5734)
* is_training to is_test in dropout op
* handle dropout and batch_norm operator when prune pdesc in testing mode
* handle dropout and batch_norm operator when prune pdesc in testing mode
* add get_inference_program method
* fix dropout op
* fix ci
* test data after each batch training
* refine code
* refine test_book3
* fix ci
* follow comments
1 parent c9172c1 commit 3a76062

File tree: 13 files changed, +141 -23 lines

paddle/framework/executor.cc

Lines changed: 1 addition & 1 deletion
@@ -120,7 +120,7 @@ void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,

   for (auto& op_desc : block.AllOps()) {
     auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
-    VLOG(10) << op->DebugString();
+    VLOG(3) << op->DebugString();
     op->Run(*local_scope, *device);
   }
   if (create_local_scope) {

paddle/framework/prune.cc

Lines changed: 23 additions & 0 deletions
@@ -26,6 +26,8 @@ namespace framework {

 const std::string kFeedOpType = "feed";
 const std::string kFetchOpType = "fetch";
+const std::string kDropOutOpType = "dropout";
+const std::string kBatchNormOpType = "batch_norm";

 bool HasDependentVar(const OpDesc& op_desc,
                      const std::set<std::string>& dependent_vars) {

@@ -106,5 +108,26 @@ void Prune(const ProgramDesc& input, ProgramDesc* output) {
   prune_impl(input, output, 0);
 }

+void inference_optimize_impl(const ProgramDesc& input, ProgramDesc* output,
+                             int block_id) {
+  *output = input;
+  auto* op_field = output->mutable_blocks(block_id)->mutable_ops();
+  for (auto& op_desc : *op_field) {
+    if (op_desc.type() == kDropOutOpType ||
+        op_desc.type() == kBatchNormOpType) {
+      for (auto& attr : *op_desc.mutable_attrs()) {
+        if (attr.name() == "is_test") {
+          attr.set_b(true);
+          break;
+        }
+      }
+    }
+  }
+}
+
+void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output) {
+  inference_optimize_impl(input, output, 0);
+}
+
 }  // namespace framework
 }  // namespace paddle
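InferenceOptimize copies the whole program and flips the is_test attribute to true on every dropout and batch_norm op in block 0, so a trained program can be replayed in test mode without being rebuilt. A minimal Python sketch of the same pass, written against the generated framework.proto messages rather than the pybind wrapper (CopyFrom, blocks, ops, and attrs are standard protobuf accessors; field names follow the C++ above):

import framework_pb2  # the generated framework.proto module

# Sketch only: mirrors inference_optimize_impl using protobuf's Python API.
def inference_optimize_sketch(input_desc, output_desc, block_id=0):
    output_desc.CopyFrom(input_desc)                 # *output = input
    for op_desc in output_desc.blocks[block_id].ops:
        if op_desc.type in ("dropout", "batch_norm"):
            for attr in op_desc.attrs:
                if attr.name == "is_test":
                    attr.b = True                    # force the test-mode path
                    break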

paddle/framework/prune.h

Lines changed: 2 additions & 0 deletions
@@ -22,5 +22,7 @@ namespace framework {

 void Prune(const ProgramDesc& input, ProgramDesc* output);

+void InferenceOptimize(const ProgramDesc& input, ProgramDesc* output);
+
 }  // namespace framework
 }  // namespace paddle

paddle/operators/dropout_op.cc

Lines changed: 4 additions & 4 deletions
@@ -30,7 +30,7 @@ class DropoutOp : public framework::OperatorWithKernel {

     auto x_dims = ctx->GetInputDim("X");
     ctx->SetOutputDim("Out", x_dims);
-    if (ctx->Attrs().Get<bool>("is_training") == true) {
+    if (ctx->Attrs().Get<bool>("is_test") == false) {
       ctx->SetOutputDim("Mask", x_dims);
     }
     ctx->ShareLoD("X", /*->*/ "Out");

@@ -49,7 +49,7 @@ class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {

     AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
         .SetDefault(.5f);
-    AddAttr<bool>("is_training", "True if in training phase.").SetDefault(true);
+    AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);

     AddComment(R"DOC(

@@ -71,8 +71,8 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_training"), true,
-                      "GradOp is only callable when is_training is true");
+    PADDLE_ENFORCE_EQ(ctx->Attrs().Get<bool>("is_test"), false,
+                      "GradOp is only callable when is_test is false");

     PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE(ctx->HasInput("Mask"), "Mask must not be null.");

paddle/operators/dropout_op.cu

Lines changed: 1 addition & 1 deletion
@@ -59,7 +59,7 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
     auto Y = EigenMatrix<T>::Reshape(*y, 1);

     auto place = context.GetEigenDevice<Place>();
-    if (context.Attr<bool>("is_training")) {
+    if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
       auto* mask_data = mask->mutable_data<T>(context.GetPlace());
       int size = framework::product(mask->dims());

paddle/operators/dropout_op.h

Lines changed: 3 additions & 3 deletions
@@ -35,7 +35,7 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
     auto* y_data = y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");

-    if (context.Attr<bool>("is_training")) {
+    if (!context.Attr<bool>("is_test")) {
       auto* mask = context.Output<Tensor>("Mask");
       auto* mask_data = mask->mutable_data<T>(context.GetPlace());
       int seed = context.Attr<int>("seed");

@@ -65,8 +65,8 @@ template <typename Place, typename T>
 class DropoutGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
-    PADDLE_ENFORCE(context.Attr<bool>("is_training"),
-                   "GradOp is only callable when is_training is true");
+    PADDLE_ENFORCE(!context.Attr<bool>("is_test"),
+                   "GradOp is only callable when is_test is false");

     auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
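Both kernels now branch on is_test rather than is_training. In training mode they sample a per-element mask and keep it around for DropoutGradKernel; the test-mode branch is not shown in these hunks, but a NumPy sketch of the two modes is below. The test-time scaling by 1 - dropout_prob is an assumption based on classic, non-inverted dropout, not something this diff confirms:

import numpy as np

def dropout_sketch(x, dropout_prob=0.5, is_test=False, seed=0):
    if not is_test:
        # Training: zero out units with probability dropout_prob and save
        # the mask, which the grad kernel reuses for backprop.
        rng = np.random.RandomState(seed)
        mask = (rng.uniform(size=x.shape) >= dropout_prob).astype(x.dtype)
        return x * mask, mask
    # Test branch (assumed): no randomness, scale by the keep probability.
    return x * (1.0 - dropout_prob), None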

paddle/pybind/pybind.cc

Lines changed: 5 additions & 0 deletions
@@ -293,6 +293,11 @@ All parameter, weight, gradient are variables in Paddle.
     Prune(*prog_with_targets.Proto(), &pruned_desc);
     return new ProgramDescBind(pruned_desc);
   });
+  m.def("inference_optimize", [](ProgramDescBind &origin) {
+    ProgramDesc pruned_desc;
+    InferenceOptimize(*(origin.Proto()), &pruned_desc);
+    return new ProgramDescBind(pruned_desc);
+  });
   m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")

python/paddle/v2/fluid/evaluator.py

Lines changed: 3 additions & 0 deletions
@@ -33,6 +33,9 @@ def __init__(self, name, **kwargs):
         else:
             self._main_program = g_main_program

+    def states(self):
+        return self._states
+
     def _update_ops(self, *args, **kwargs):
         """
         append update ops to the global states

python/paddle/v2/fluid/framework.py

Lines changed: 7 additions & 0 deletions
@@ -511,6 +511,13 @@ def prune(self, targets):
         res.sync_with_cpp()
         return res

+    def inference_optimize(self):
+        res = Program()
+        res.desc = core.inference_optimize(self.desc)
+        res.blocks = [Block(res, i) for i in xrange(res.desc.num_blocks())]
+        res.sync_with_cpp()
+        return res
+
     @staticmethod
     def parse_from_string(binary_str):
         p = Program()
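inference_optimize deliberately mirrors prune directly above it: both build a fresh Program, attach a new desc from core, rebuild the Block wrappers, and sync with the C++ side. The intended order, as save_inference_model in io.py below uses it, is prune first, then optimize. A short sketch, where accuracy stands in for any fetch-target Variable:

# Sketch: prune to the fetch targets, then flip train-only ops to test mode.
pruned_program = main_program.prune(targets=[accuracy])
inference_program = pruned_program.inference_optimize()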

python/paddle/v2/fluid/io.py

Lines changed: 16 additions & 3 deletions
@@ -6,7 +6,8 @@

 __all__ = [
     'save_vars', 'save_params', 'save_persistables', 'load_vars', 'load_params',
-    'load_persistables', "save_inference_model", "load_inference_model"
+    'load_persistables', "save_inference_model", "load_inference_model",
+    "get_inference_program"
 ]


@@ -151,6 +152,17 @@ def load_persistables(executor, dirname, main_program=None):
         predicate=is_persistable)


+def get_inference_program(target_vars, main_program=None):
+    if main_program is None:
+        main_program = g_main_program
+    if not isinstance(target_vars, list):
+        target_vars = [target_vars]
+
+    pruned_program = main_program.prune(targets=target_vars)
+    inference_program = pruned_program.inference_optimize()
+    return inference_program
+
+
 def save_inference_model(dirname,
                          feeded_var_names,
                          target_vars,

@@ -177,13 +189,14 @@ def save_inference_model(dirname,
     if not os.path.isdir(dirname):
         os.makedirs(dirname)

-    pruned_program = main_program.prune(target_vars)
+    pruned_program = main_program.prune(targets=target_vars)
+    inference_program = pruned_program.inference_optimize()
     fetch_var_names = [v.name for v in target_vars]

     model_file_name = dirname + "/__model__"
     with open(model_file_name, "w") as f:
         pickle.dump({
-            "program_desc_str": inference_program.desc.serialize_to_string(),
+            "program_desc_str": inference_program.desc.serialize_to_string(),
             "feed_var_names": feeded_var_names,
             "fetch_var_names": fetch_var_names
         }, f, -1)
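get_inference_program packages that prune-then-optimize sequence for the "test data after each batch training" workflow named in the commit message: training code can fetch a test-mode program between batches without touching the main program. A hedged usage sketch, where the network, variables, and feed tensors are all illustrative:

from paddle.v2.fluid.io import get_inference_program

# Sketch only: avg_cost and accuracy stand in for target Variables from a
# user-defined network; exe is an Executor, and test_x/test_y are tensors
# drawn from a test reader.
inference_program = get_inference_program([avg_cost, accuracy])
outs = exe.run(inference_program,
               feed={'x': test_x, 'y': test_y},
               fetch_list=[avg_cost, accuracy])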
