Skip to content

Commit 781dc72

Browse files
authored
Merge pull request #13128 from chengduoZH/refine_while_op
Refine while op
2 parents 91e10fb + 16359da commit 781dc72

File tree

3 files changed

+14
-3
lines changed

3 files changed

+14
-3
lines changed

paddle/fluid/API.spec

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ paddle.fluid.layers.argsort ArgSpec(args=['input', 'axis', 'name'], varargs=None
190190
paddle.fluid.layers.ones ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
191191
paddle.fluid.layers.zeros ArgSpec(args=['shape', 'dtype', 'force_cpu'], varargs=None, keywords=None, defaults=(False,))
192192
paddle.fluid.layers.reverse ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=None)
193-
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'name'], varargs=None, keywords=None, defaults=(None,))
193+
paddle.fluid.layers.While.__init__ ArgSpec(args=['self', 'cond', 'is_test', 'name'], varargs=None, keywords=None, defaults=(False, None))
194194
paddle.fluid.layers.While.block ArgSpec(args=['self'], varargs=None, keywords=None, defaults=None)
195195
paddle.fluid.layers.Switch.__init__ ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,))
196196
paddle.fluid.layers.Switch.case ArgSpec(args=['self', 'condition'], varargs=None, keywords=None, defaults=None)

paddle/fluid/operators/while_op.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ class WhileOp : public framework::OperatorBase {
5555
auto step_scopes =
5656
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
5757

58+
bool is_test = Attr<bool>("is_test");
5859
PADDLE_ENFORCE(platform::is_cpu_place(cond.place()),
5960
"Condition of while op must in CPU memory.");
6061
while (cond.data<bool>()[0]) {
@@ -63,6 +64,10 @@ class WhileOp : public framework::OperatorBase {
6364

6465
executor.Run(*program, &current_scope, block->ID(),
6566
false /*create_local_scope*/);
67+
68+
if (is_test) {
69+
scope.DeleteScope(&current_scope);
70+
}
6671
}
6772
}
6873
};
@@ -88,6 +93,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
8893
"variables generated in the i'th step.");
8994
AddAttr<framework::BlockDesc *>(kStepBlock,
9095
"The step block inside WhileOp");
96+
AddAttr<bool>("is_test", "True if in test phase.").SetDefault(false);
9197
AddComment(R"DOC(
9298
)DOC");
9399
}
@@ -103,6 +109,8 @@ class WhileGradOp : public framework::OperatorBase {
103109
private:
104110
void RunImpl(const framework::Scope &scope,
105111
const platform::Place &dev_place) const override {
112+
PADDLE_ENFORCE(!Attr<bool>("is_test"),
113+
"GradOp is only callable when is_test is false");
106114
// get device context from pool
107115
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
108116
auto &dev_ctx = *pool.Get(dev_place);

python/paddle/fluid/layers/control_flow.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -661,6 +661,7 @@ class While(object):
661661
662662
Args:
663663
cond (Variable): condition used to compare.
664+
is_test(bool): A flag indicating whether execution is in test phase.
664665
name (str): The name of this layer.
665666
666667
Examples:
@@ -683,7 +684,7 @@ class While(object):
683684
IN_WHILE_BLOCK = 1
684685
AFTER_WHILE_BLOCK = 2
685686

686-
def __init__(self, cond, name=None):
687+
def __init__(self, cond, is_test=False, name=None):
687688
self.helper = LayerHelper("while", name=name)
688689
self.status = While.BEFORE_WHILE_BLOCK
689690
if not isinstance(cond, Variable):
@@ -694,6 +695,7 @@ def __init__(self, cond, name=None):
694695
if reduce(lambda a, b: a * b, cond.shape, 1) != 1:
695696
raise TypeError("condition should be a bool scalar")
696697
self.cond_var = cond
698+
self.is_test = is_test
697699

698700
def block(self):
699701
return WhileGuard(self)
@@ -735,7 +737,8 @@ def _complete(self):
735737
},
736738
outputs={'Out': out_vars,
737739
'StepScopes': [step_scope]},
738-
attrs={'sub_block': while_block})
740+
attrs={'sub_block': while_block,
741+
"is_test": self.is_test})
739742

740743

741744
def lod_rank_table(x, level=0):

0 commit comments

Comments
 (0)