@@ -929,16 +929,17 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
929
929
while_loop is one of the control flows. Repeats while_loop `body` until `cond` returns False.
930
930
931
931
Args:
932
- cond(Callable): A callable returning a boolean tensor controlling whether to continue looping.
933
- body(Callable): A callable returning a tuple or list of tensors of the same arity (length and structure)
934
- and types as ``loops_vars`` .
935
- loop_vars(list|tuple): A list or tuple of tensors that is passed to both ``cond`` and ``body`` .
932
+ cond(Callable): A callable returning a boolean tensor controlling whether to continue looping. And ``cond`` takes
933
+ as many arguments as ``loop_vars`` .
934
+ body(Callable): A callable returning a tuple or list of tensors or LoDTensorArrays of the same arity
935
+ (length and structure) and types as ``loops_vars`` . And ``body`` takes as many arguments as ``loop_vars`` .
936
+ loop_vars(list|tuple): A list or tuple of tensors or LoDTensorArrays that is passed to both ``cond`` and ``body`` .
936
937
is_test(bool, optional): A flag indicating whether execution is in test phase. Default value is False.
937
938
name(str, optional): Normally there is no need for users to set this property. For more information, please
938
939
refer to :ref:`api_guide_Name`. Default is None.
939
940
940
941
Returns:
941
- A list or tuple of tensors which returned by ``body`` .
942
+ A list or tuple of tensors or LoDTensorArrays which returned by ``body`` .
942
943
943
944
Returen type:
944
945
list(Variable)|tuple(Variable).
@@ -951,29 +952,31 @@ def while_loop(cond, body, loop_vars, is_test=False, name=None):
951
952
TypeError: If the type of ``cond`` returns is not a boolean variable.
952
953
TypeError: If the shape of ``cond`` returns is not equals 1.
953
954
ValueError: If the ``var_loops`` is empty.
955
+ ValueError: If the length or type of ``body`` returns is not same as ``loop_vars``.
954
956
955
957
Examples:
956
958
.. code-block:: python
957
959
958
960
import paddle.fluid as fluid
959
961
import paddle.fluid.layers as layers
960
962
961
- def cond(i):
962
- return layers.less_than(i, ten)
963
+ def cond(i, ten ):
964
+ return i < ten
963
965
964
- def body(i):
965
- return layers.increment(x=i, value=1, in_place=True)
966
+ def body(i, ten):
967
+ i = i + 1
968
+ return [i, ten]
966
969
967
970
main_program = fluid.default_main_program()
968
971
startup_program = fluid.default_startup_program()
969
972
970
973
with fluid.program_guard(main_program, startup_program):
971
974
i = layers.fill_constant(shape=[1], dtype='int64', value=0) # loop counter
972
975
ten = layers.fill_constant(shape=[1], dtype='int64', value=10) # loop length
973
- out = layers.while_loop(cond, body, [i])
976
+ i, ten = layers.while_loop(cond, body, [i, ten ])
974
977
975
978
exe = fluid.Executor(fluid.CPUPlace())
976
- res = exe.run(main_program, feed={}, fetch_list=out )
979
+ res = exe.run(main_program, feed={}, fetch_list=[i] )
977
980
print(res) # [array([10])]
978
981
"""
979
982
helper = LayerHelper ('while_loop' , ** locals ())
@@ -1000,11 +1003,13 @@ def body(i):
1000
1003
while_loop_block = While (pre_cond , is_test , name )
1001
1004
with while_loop_block .block ():
1002
1005
output_vars = body (* loop_vars )
1006
+ if not isinstance (output_vars , (list , tuple )):
1007
+ output_vars = [output_vars ]
1008
+ if len (output_vars ) != len (loop_vars ):
1009
+ raise ValueError ("body in while_loop should return the same arity "
1010
+ "(length and structure) and types as loop_vars" )
1011
+ now_cond = cond (* output_vars )
1003
1012
map_structure (assign , output_vars , loop_vars )
1004
- if len (loop_vars ) == 1 :
1005
- now_cond = cond (output_vars )
1006
- else :
1007
- now_cond = cond (* output_vars )
1008
1013
assign (now_cond , pre_cond )
1009
1014
return loop_vars
1010
1015
0 commit comments