Skip to content

Commit b8eb52f

Browse files
fix the fp16 support of assgin and squeeze Op test=develop (#24939)
1 parent dc7d34e commit b8eb52f

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

python/paddle/fluid/layers/nn.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7479,7 +7479,7 @@ def squeeze(input, axes, name=None):
74797479
Out.shape = [1,3,5]
74807480

74817481
Args:
7482-
input (Variable): The input Tensor. Support data type: float32, float64, int8, int32, int64.
7482+
input (Variable): The input Tensor. Support data type: float16, float32, float64, int8, int32, int64.
74837483
axes (list): One integer or List of integers, indicating the dimensions to be squeezed.
74847484
Axes range is :math:`[-rank(input), rank(input))`.
74857485
If axes is negative, :math:`axes=axes+rank(input)`.
@@ -7499,9 +7499,9 @@ def squeeze(input, axes, name=None):
74997499

75007500
"""
75017501
helper = LayerHelper("squeeze", **locals())
7502-
check_variable_and_dtype(input, 'input',
7503-
['float32', 'float64', 'int8', 'int32', 'int64'],
7504-
'squeeze')
7502+
check_variable_and_dtype(
7503+
input, 'input',
7504+
['float16', 'float32', 'float64', 'int8', 'int32', 'int64'], 'squeeze')
75057505
check_type(axes, 'axes', list, 'squeeze')
75067506
out = helper.create_variable_for_type_inference(dtype=input.dtype)
75077507
x_shape = helper.create_variable_for_type_inference(dtype=input.dtype)

python/paddle/fluid/layers/tensor.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ def assign(input, output=None):
566566
567567
Parameters:
568568
input (Variable|numpy.ndarray): A tensor or numpy ndarray, its data type supports
569-
float32, float64, int32 and int64.
569+
float16, float32, float64, int32 and int64.
570570
output (Variable, optional): A tensor. If :attr:`output` is None, a new tensor will
571571
be created as :attr:`output`. Default: None.
572572
@@ -587,9 +587,10 @@ def assign(input, output=None):
587587
helper = LayerHelper('assign', **locals())
588588
check_type(input, 'input', (Variable, numpy.ndarray), 'assign')
589589
if isinstance(input, Variable):
590-
check_dtype(input.dtype, 'input',
591-
['float32', 'float64', 'int32', 'int64', 'bool'], 'assign',
592-
'(When the type of input in assign is Variable.)')
590+
check_dtype(
591+
input.dtype, 'input',
592+
['float16', 'float32', 'float64', 'int32', 'int64', 'bool'],
593+
'assign', '(When the type of input in assign is Variable.)')
593594
if output is None:
594595
output = helper.create_variable_for_type_inference(
595596
dtype=input.dtype)

0 commit comments

Comments
 (0)