
Commit 0a80351

[API Compatibility] add mean, Tensor.mean (#74955)
* resolve conflicts
* update tests, enhance performance
* add backward tests, use cast for non-inplace op
* add tests
1 parent a0708e0 commit 0a80351

File tree

4 files changed: +858 -12 lines changed


python/paddle/tensor/math.py

Lines changed: 2 additions & 4 deletions
@@ -4483,7 +4483,7 @@ def cumprod(
         if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
             dtype = convert_np_dtype_to_dtype_(dtype)
         if x.dtype != dtype:
-            x = cast_(x, dtype)
+            x = cast(x, dtype)
 
     if in_dynamic_or_pir_mode():
         return _C_ops.cumprod(x, dim, False, False)
@@ -4530,9 +4530,7 @@ def cumprod_(
     if dim is None:
         dim = -1
         x = _C_ops.flatten_(x, 0, len(x.shape) - 1)
-    if dtype is None:
-        dtype = x.dtype
-    else:
+    if dtype is not None:
         if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
             dtype = convert_np_dtype_to_dtype_(dtype)
         if x.dtype != dtype:
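The cumprod change above swaps the in-place cast_ for the out-of-place cast, so passing dtype casts a copy instead of mutating the tensor the caller handed in. A minimal usage sketch of the intended behavior (assuming a Paddle build that includes this commit; the printed dtypes are illustrative):

import paddle

x = paddle.to_tensor([1, 2, 3], dtype='int32')

# dtype now triggers an out-of-place cast inside cumprod: the result is
# float64, while the caller's tensor keeps its original dtype.
y = paddle.cumprod(x, dim=0, dtype='float64')

print(y.dtype)  # paddle.float64
print(x.dtype)  # paddle.int32 -- unchanged, because cast (not cast_) is used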

python/paddle/tensor/stat.py

Lines changed: 26 additions & 8 deletions
@@ -27,40 +27,45 @@
 )
 from paddle.utils.decorator_utils import (
     ParamAliasDecorator,
+    param_two_alias,
     param_two_alias_one_default,
 )
 
 from ..base.data_feeder import check_type, check_variable_and_dtype
 from ..common_ops_import import Variable
-from ..framework import (
-    LayerHelper,
-    core,
-)
+from ..framework import LayerHelper, convert_np_dtype_to_dtype_, core
+from .manipulation import cast
 from .math import _get_reduce_axis_with_tensor
 
 if TYPE_CHECKING:
     from collections.abc import Sequence
 
     from paddle import Tensor
+    from paddle._typing import DTypeLike
 
 _Interpolation: TypeAlias = Literal[
     'linear', 'higher', 'lower', 'midpoint', 'nearest'
 ]
 __all__ = []
 
 
+@param_two_alias(["x", "input"], ["axis", "dim"])
 def mean(
     x: Tensor,
     axis: int | Sequence[int] | None = None,
     keepdim: bool = False,
     name: str | None = None,
+    *,
+    dtype: DTypeLike | None = None,
+    out: Tensor | None = None,
 ) -> Tensor:
     """
     Computes the mean of the input tensor's elements along ``axis``.
 
     Args:
         x (Tensor): The input Tensor with data type bool, bfloat16, float16, float32,
             float64, int32, int64, complex64, complex128.
+            alias: ``input``
         axis (int|list|tuple|None, optional): The axis along which to perform mean
             calculations. ``axis`` should be int, list(int) or tuple(int). If
             ``axis`` is a list/tuple of dimension(s), mean is calculated along
@@ -69,13 +74,16 @@ def mean(
             ``axis`` or element(s) of ``axis`` is less than 0, it works the
             same way as :math:`axis + D` . If ``axis`` is None, mean is
             calculated over all elements of ``x``. Default is None.
+            alias: ``dim``
         keepdim (bool, optional): Whether to reserve the reduced dimension(s)
             in the output Tensor. If ``keepdim`` is True, the dimensions of
             the output Tensor is the same as ``x`` except in the reduced
             dimensions(it is of size 1 in this case). Otherwise, the shape of
             the output Tensor is squeezed in ``axis`` . Default is False.
         name (str|None, optional): Name for the operation (optional, default is None).
             For more information, please refer to :ref:`api_guide_Name`.
+        dtype (str): The desired data type of returned tensor. Default: None.
+        out(Tensor|None, optional): The output tensor. Default: None.
 
     Returns:
         Tensor, results of average along ``axis`` of ``x``, with the same data
@@ -110,9 +118,19 @@ def mean(
             >>> out4 = paddle.mean(x, axis=[0, 2])
             >>> print(out4.numpy())
             [ 8.5 12.5 16.5]
+            >>> out5 = paddle.mean(x, dtype='float64')
+            >>> out5
+            Tensor(shape=[], dtype=float64, place=Place(gpu:0), stop_gradient=True,
+                   12.50000000)
     """
+    if dtype is not None:
+        if not isinstance(dtype, (core.VarDesc.VarType, core.DataType)):
+            dtype = convert_np_dtype_to_dtype_(dtype)
+        if x.dtype != dtype:
+            x = cast(x, dtype)
+
     if in_dynamic_or_pir_mode():
-        return _C_ops.mean(x, axis, keepdim)
+        return _C_ops.mean(x, axis, keepdim, out=out)
     else:
         reduce_all, axis = _get_reduce_axis_with_tensor(axis, x)
         check_variable_and_dtype(
@@ -146,14 +164,14 @@ def mean(
         helper = LayerHelper('mean', **locals())
 
         attrs = {'dim': axis, 'keep_dim': keepdim, 'reduce_all': reduce_all}
-        out = helper.create_variable_for_type_inference(x.dtype)
+        out_tensor = helper.create_variable_for_type_inference(x.dtype)
         helper.append_op(
             type='reduce_mean',
             inputs={'X': x},
-            outputs={'Out': out},
+            outputs={'Out': out_tensor},
             attrs=attrs,
         )
-        return out
+        return out_tensor
 
 
 @ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
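Taken together, the stat.py changes add keyword-only dtype and out parameters to paddle.mean and enable input/dim parameter aliases via @param_two_alias. A minimal usage sketch (assumes a Paddle build containing this commit; per the diff above, out= is only forwarded to the C++ op in dynamic/PIR mode):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])

# dtype is normalized with convert_np_dtype_to_dtype_ and the input is
# cast (non-inplace) before the reduction, so the result is float64.
m = paddle.mean(x, dtype='float64')
print(m.dtype)   # paddle.float64

# Aliases from @param_two_alias(["x", "input"], ["axis", "dim"]):
# `input` maps to `x`, `dim` maps to `axis`.
m2 = paddle.mean(input=x, dim=0, keepdim=True)
print(m2.shape)  # [1, 2]

# `out` is forwarded as _C_ops.mean(x, axis, keepdim, out=out),
# writing the result into a preallocated tensor.
buf = paddle.zeros([])
paddle.mean(x, out=buf)
print(buf)       # contains 2.5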
