Skip to content

Commit da338d7

Browse files
authored
Modify english document and unittest of while_loop (#22615) (#22629)
cherry-pick #22615
1 parent fd37536 commit da338d7

File tree

2 files changed

+44
-15
lines changed

2 files changed

+44
-15
lines changed

python/paddle/fluid/layers/control_flow.py

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
929929
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
930930
931931
Args:
932-
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping.
933-
body(Callable): A callable returning a tuple or list of tensors of the same arity (length and structure)
934-
and types as ``loops_vars`` .
935-
loop_vars(list|tuple): A list or tuple of tensors that is passed to both ``cond`` and ``body`` .
932+
cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
933+
as many arguments as ``loop_vars`` .
934+
body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
935+
(length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
936+
loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
936937
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
937938
name(str, optional): Normally there is no need for users to set this property. For more information, please
938939
refer to :ref:`api_guide_Name`. Default is None.
939940
940941
Returns:
941-
A list or tuple of tensors which returned by ``body`` .
942+
A list or tuple of tensors or LoDTensorArrays which returned by ``body`` .
942943
943944
Returen type:
944945
list(Variable)|tuple(Variable).
@@ -951,29 +952,31 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
951952
TypeError: If the type of ``cond`` returns is not a boolean variable.
952953
TypeError: If the shape of ``cond`` returns is not equals 1.
953954
ValueError: If the ``var_loops`` is empty.
955+
ValueError: If the length or type of ``body`` returns is not same as ``loop_vars``.
954956
955957
Examples:
956958
.. code-block:: python
957959
958960
import paddle.fluid as fluid
959961
import paddle.fluid.layers as layers
960962
961-
def cond(i):
962-
return layers.less_than(i, ten)
963+
def cond(i, ten):
964+
return i < ten
963965
964-
def body(i):
965-
return layers.increment(x=i, value=1, in_place=True)
966+
def body(i, ten):
967+
i = i + 1
968+
return [i, ten]
966969
967970
main_program = fluid.default_main_program()
968971
startup_program = fluid.default_startup_program()
969972
970973
with fluid.program_guard(main_program, startup_program):
971974
i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter
972975
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length
973-
out = layers.while_loop(cond, body, [i])
976+
i, ten = layers.while_loop(cond, body, [i, ten])
974977
975978
exe = fluid.Executor(fluid.CPUPlace())
976-
res = exe.run(main_program, feed={}, fetch_list=out)
979+
res = exe.run(main_program, feed={}, fetch_list=[i])
977980
print(res) # [array([10])]
978981
"""
979982
helper = LayerHelper('while_loop', **locals())
@@ -1000,11 +1003,13 @@ def body(i):
10001003
while_loop_block = While(pre_cond, is_test, name)
10011004
with while_loop_block.block():
10021005
output_vars = body(*loop_vars)
1006+
if not isinstance(output_vars, (list, tuple)):
1007+
output_vars = [output_vars]
1008+
if len(output_vars) != len(loop_vars):
1009+
raise ValueError("body in while_loop should return the same arity "
1010+
"(length and structure) and types as loop_vars")
1011+
now_cond = cond(*output_vars)
10031012
map_structure(assign, output_vars, loop_vars)
1004-
if len(loop_vars) == 1:
1005-
now_cond = cond(output_vars)
1006-
else:
1007-
now_cond = cond(*output_vars)
10081013
assign(now_cond, pre_cond)
10091014
return loop_vars
10101015

python/paddle/fluid/tests/unittests/test_while_loop_op.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -311,9 +311,19 @@ def cond_returns_bool_tensor(i):
311311
def cond_returns_2d_tensor(i):
312312
return layers.less_than(i, ten_2d)
313313

314+
def cond_receives_two_args(i, ten):
315+
return layers.less_than(i, ten)
316+
314317
def body(i):
315318
return layers.increment(i)
316319

320+
def body_returns_error_length(i):
321+
i = layers.increment(i)
322+
return [i, i]
323+
324+
def body_returns_error_type(i, ten):
325+
return layers.increment(i)
326+
317327
main_program = Program()
318328
startup_program = Program()
319329
with program_guard(main_program, startup_program):
@@ -367,6 +377,20 @@ def type_error_shape_cond_returns_2d():
367377

368378
self.assertRaises(TypeError, type_error_shape_cond_returns_2d)
369379

380+
# The length of `body` returns in Op(while_loop) must be same as `loop_vars`
381+
def value_error_body_returns_error_length():
382+
out = layers.while_loop(cond_returns_bool_tensor,
383+
body_returns_error_length, [data])
384+
385+
self.assertRaises(ValueError, value_error_body_returns_error_length)
386+
387+
# The type of `body` returns in Op(while_loop) must be same as `loop_vars`
388+
def value_error_body_returns_error_type():
389+
out = layers.while_loop(cond_receives_two_args,
390+
body_returns_error_type, [data, ten])
391+
392+
self.assertRaises(ValueError, value_error_body_returns_error_type)
393+
370394

371395
if __name__ == '__main__':
372396
unittest.main()

0 commit comments

Comments
 (0)