Commit 9f97d37

[API compatibility] torch.as_tensor, torch.finfo, torch.is_complex, torch.nn.functional.pad (#74456)
* [API compatibility] paddle.to_tensor
* choose suitable place
* fix annotation
* fix whitespace
* fix getfullargspec
* torch.as_tensor, torch.finfo, torch.is_complex, torch.nn.functional.pad
* fix type_check
* delete example

1 parent 9df3eef commit 9f97d37

9 files changed: +200 additions, -24 deletions

python/paddle/framework/dtype.py

Lines changed: 7 additions & 0 deletions

@@ -13,6 +13,7 @@
 # limitations under the License.

 import paddle
+from paddle.utils.decorator_utils import ParamAliasDecorator

 from ..base import framework
 from ..base.core import (
@@ -205,16 +206,22 @@ def iinfo(dtype):
     return core_iinfo(dtype)


+@ParamAliasDecorator({"dtype": ["type"]})
 def finfo(dtype):
     """

     ``paddle.finfo`` is a function that returns an object that represents the numerical properties of a floating point
     ``paddle.dtype``.
     This is similar to `numpy.finfo <https://numpy.org/doc/stable/reference/generated/numpy.finfo.html#numpy-finfo>`_.

+    .. note::
+        Alias Support: The parameter name ``type`` can be used as an alias for ``dtype``.
+        For example, ``type=paddle.float32`` is equivalent to ``dtype=paddle.float32``.
+
     Args:
         dtype(paddle.dtype|string): One of ``paddle.float16``, ``paddle.float32``, ``paddle.float64``, ``paddle.bfloat16``,
             ``paddle.complex64``, and ``paddle.complex128``.
+        type: An alias for ``dtype``, with identical behavior.

     Returns:
         An ``finfo`` object, which has the following 8 attributes:
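
With the decorator applied, both keyword spellings reach the same implementation. A quick sanity check, as a sketch assuming a Paddle build that includes this commit:

import paddle

# `type` is rewritten to `dtype` by ParamAliasDecorator before the
# body of paddle.finfo runs, so both calls are equivalent.
a = paddle.finfo(dtype=paddle.float32)
b = paddle.finfo(type=paddle.float32)
assert a.bits == b.bits == 32
assert a.eps == b.eps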

python/paddle/nn/functional/common.py

Lines changed: 8 additions & 0 deletions

@@ -32,6 +32,7 @@
 )
 from paddle.tensor.creation import full
 from paddle.utils import deprecated
+from paddle.utils.decorator_utils import ParamAliasDecorator
 from paddle.utils.layers_utils import NotSupportedTensorArgumentError

 from ...base.data_feeder import (
@@ -1857,6 +1858,7 @@ def feature_alpha_dropout(
     )


+@ParamAliasDecorator({"x": ["input"]})
 def pad(
     x: Tensor,
     pad: ShapeLike,
@@ -1892,8 +1894,14 @@ def pad(
     4. If mode is ``'reflect'``, pad[0] and pad[1] must be no greater than width-1. The height and depth
        dimension has the same condition.

+    .. note::
+        Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
+        For example, ``input=tensor_x`` is equivalent to ``x=tensor_x``.
+
+
     Args:
         x (Tensor): The input tensor with data type float32, float64, int32, int64, complex64 or complex128.
+        input: An alias for ``x``, with identical behavior.
         pad (Tensor|list[int]|tuple[int]): The padding size with data type int. Refer to Note for details.
         mode (str, optional): Four modes: ``'constant'`` (default), ``'reflect'``, ``'replicate'``, ``'circular'``. Default is ``'constant'``.

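
For illustration, the new alias on ``paddle.nn.functional.pad`` can be exercised directly; a minimal sketch assuming this commit is installed:

import numpy as np
import paddle
import paddle.nn.functional as F

x = paddle.ones([1, 1, 2, 3])
# `input=` is remapped to `x=` before pad executes, so both calls
# produce the same padded tensor.
out_x = F.pad(x=x, pad=[1, 1, 2, 2])
out_input = F.pad(input=x, pad=[1, 1, 2, 2])
np.testing.assert_array_equal(out_x.numpy(), out_input.numpy())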

python/paddle/tensor/attribute.py

Lines changed: 8 additions & 0 deletions

@@ -20,6 +20,7 @@

 import paddle
 from paddle import _C_ops
+from paddle.utils.decorator_utils import ParamAliasDecorator

 from ..base.data_feeder import check_type, check_variable_and_dtype
 from ..base.framework import in_dynamic_or_pir_mode, use_pir_api
@@ -144,11 +145,18 @@ def shape(input: Tensor) -> Tensor:
     return out


+@ParamAliasDecorator({"x": ["input"]})
 def is_complex(x: Tensor) -> bool:
     """Return whether x is a tensor of complex data type(complex64 or complex128).

+
+    .. note::
+        Alias Support: The parameter name ``input`` can be used as an alias for ``x``.
+        For example, ``input=tensor_x`` is equivalent to ``x=tensor_x``.
+
     Args:
         x (Tensor): The input tensor.
+        input: An alias for ``x``, with identical behavior.

     Returns:
         bool: True if the data type of the input is complex data type, otherwise false.
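
The same pattern applies here; a short sketch assuming this commit:

import paddle

real = paddle.randn([2, 3])
cplx = real + 1j * paddle.randn([2, 3])
# Either spelling resolves to the canonical `x` parameter.
assert not paddle.is_complex(x=real)
assert paddle.is_complex(input=cplx)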

python/paddle/tensor/creation.py

Lines changed: 7 additions & 0 deletions

@@ -24,6 +24,7 @@

 import paddle
 from paddle import _C_ops
+from paddle.utils.decorator_utils import ParamAliasDecorator
 from paddle.utils.inplace_utils import inplace_apis_in_dygraph_only

 from ..base.data_feeder import (
@@ -876,6 +877,7 @@ def _to_tensor_static(
     return output


+@ParamAliasDecorator({"place": ["device"]})
 def to_tensor(
     data: TensorLike | NestedNumericSequence,
     dtype: DTypeLike | None = None,
@@ -889,6 +891,10 @@ def to_tensor(
     If the ``data`` is already a Tensor, copy will be performed and return a new tensor.
     If you only want to change stop_gradient property, please call ``Tensor.stop_gradient = stop_gradient`` directly.

+    .. note::
+        Alias Support: The parameter name ``device`` can be used as an alias for ``place``.
+        For example, ``device=paddle.CUDAPlace(0)`` is equivalent to ``place=paddle.CUDAPlace(0)``.
+
     .. code-block:: text

         We use the dtype conversion rules following this:
@@ -911,6 +917,7 @@ def to_tensor(
         place(CPUPlace|CUDAPinnedPlace|CUDAPlace|str, optional): The place to allocate Tensor. Can be
             CPUPlace, CUDAPinnedPlace, CUDAPlace. Default: None, means global place. If ``place`` is
             string, It can be ``cpu``, ``gpu:x`` and ``gpu_pinned``, where ``x`` is the index of the GPUs.
+        device: An alias for ``place``, with identical behavior.
         stop_gradient(bool, optional): Whether to block the gradient propagation of Autograd. Default: True.

     Returns:
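
The torch-style spelling ``device=`` now works wherever ``place=`` does; a minimal sketch assuming this commit:

import paddle

# `device=` is rewritten to `place=` by the decorator.
t1 = paddle.to_tensor([1.0, 2.0], place=paddle.CPUPlace())
t2 = paddle.to_tensor([1.0, 2.0], device=paddle.CPUPlace())
assert t1.place == t2.place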

python/paddle/utils/decorator_utils.py

Lines changed: 8 additions & 24 deletions

@@ -12,31 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from __future__ import annotations
-
 import functools
 import inspect
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Callable,
-    Generic,
-    TypeVar,
-    cast,
-)
-
-from typing_extensions import ParamSpec
-
-if TYPE_CHECKING:
-    from collections.abc import Iterable
-
+from collections.abc import Iterable
+from typing import Any, Callable, TypeVar, cast

-_P = ParamSpec("_P")
-_R = TypeVar("_R")
-_DecoratedFunc = Callable[_P, _R]
+_F = TypeVar("_F", bound=Callable[..., Any])


-class DecoratorBase(Generic[_P, _R]):
+class DecoratorBase:
     """Decorative base class, providing a universal decorative framework.

     Subclass only needs to implement the 'process' method to define the core logic.
@@ -47,19 +31,19 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self.args = args
         self.kwargs = kwargs

-    def __call__(self, func: _DecoratedFunc[_P, _R]) -> _DecoratedFunc[_P, _R]:
+    def __call__(self, func: _F) -> _F:
         """As an entry point for decorative applications"""

         @functools.wraps(func)
-        def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
+        def wrapper(*args, **kwargs):
             # Pretreatment parameters
             processed_args, processed_kwargs = self.process(args, kwargs)
             # Call the original function
             return func(*processed_args, **processed_kwargs)

         # Keep original signature
         wrapper.__signature__ = inspect.signature(func)
-        return cast("_DecoratedFunc[_P, _R]", wrapper)
+        return cast("_F", wrapper)

     def process(
         self, args: tuple[Any, ...], kwargs: dict[str, Any]
@@ -77,7 +61,7 @@ def process(


 # Example implementation: Parameter alias decorator
-class ParamAliasDecorator(DecoratorBase[_P, _R]):
+class ParamAliasDecorator(DecoratorBase):
     """Implementation of Decorator for Parameter Alias Processing"""

     def __init__(self, alias_mapping: dict[str, Iterable[str]]) -> None:
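
The hunk above does not show ``ParamAliasDecorator.process`` itself, but the tests below pin down its contract: each alias keyword is rewritten to its canonical name, and passing both spellings at once raises ``ValueError``. A standalone sketch consistent with that contract (``remap_aliases`` is a hypothetical stand-in, not the actual implementation):

from typing import Any

def remap_aliases(
    alias_mapping: dict[str, list[str]],
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    """Hypothetical stand-in for ParamAliasDecorator.process."""
    for canonical, aliases in alias_mapping.items():
        for alias in aliases:
            if alias in kwargs:
                if canonical in kwargs:
                    raise ValueError(
                        f"Cannot specify both '{canonical}' and its alias '{alias}'"
                    )
                # Rewrite the alias keyword to the canonical parameter name.
                kwargs[canonical] = kwargs.pop(alias)
    return args, kwargs

# The mapping used for paddle.to_tensor above:
_, kw = remap_aliases({"place": ["device"]}, (), {"device": "cpu"})
assert kw == {"place": "cpu"}

Note also that the refactored ``__call__`` still pins ``wrapper.__signature__`` to the wrapped function's signature, so ``inspect.signature`` keeps reporting the canonical parameters rather than ``*args, **kwargs``.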

test/legacy_test/test_eager_tensor.py

Lines changed: 46 additions & 0 deletions

@@ -377,6 +377,52 @@ def test_to_tensor_attributes(self):
         self.assertEqual(var.dtype, paddle.float32)
         self.assertEqual(var.type, core.VarDesc.VarType.DENSE_TENSOR)

+    def test_to_tensor_param_alias(self):
+        """Test paddle.to_tensor parameter mapping ("place": ["device"])."""
+        # 1. Test equivalence of place and device parameters
+        tensor_place = paddle.to_tensor(self.array, place=paddle.CPUPlace())
+        tensor_device = paddle.to_tensor(self.array, device=paddle.CPUPlace())
+
+        np.testing.assert_array_equal(
+            tensor_device.numpy(), tensor_place.numpy()
+        )
+        self.assertEqual(tensor_device.place, tensor_place.place)
+
+        # 2. Test conflict between place and device (should raise ValueError)
+        with self.assertRaises(ValueError) as context:
+            paddle.to_tensor(
+                self.array,
+                place=paddle.CPUPlace(),
+                device=paddle.CPUPlace(),  # Conflict
+            )
+        self.assertIn(
+            "Cannot specify both 'place' and its alias 'device'",
+            str(context.exception),
+        )
+
+        # 3. Test dtype and stop_gradient consistency
+        tensor1 = paddle.to_tensor(
+            self.array, dtype="float32", device=paddle.CPUPlace()
+        )
+        tensor2 = paddle.to_tensor(
+            self.array, dtype="float32", place=paddle.CPUPlace()
+        )
+
+        self.assertEqual(tensor1.dtype, tensor2.dtype)
+        self.assertEqual(tensor1.dtype, paddle.float32)
+        self.assertTrue(tensor1.stop_gradient)
+        self.assertEqual(tensor1.stop_gradient, tensor2.stop_gradient)
+
+        # 4. Test cross-device compatibility (CPU/GPU)
+        for device in [paddle.CPUPlace()] + (
+            [paddle.CUDAPlace(0)] if core.is_compiled_with_cuda() else []
+        ):
+            tensor_device = paddle.to_tensor(self.array, device=device)
+            tensor_place = paddle.to_tensor(self.array, place=device)
+
+            self.assertEqual(tensor_device.place, tensor_place.place)
+            self.assertEqual(tensor_device.place, device)
+
     def test_list_to_tensor(self):
         array = [[[1, 2], [1, 2], [1.0, 2]], [[1, 2], [1, 2], [1, 2]]]
         var = paddle.to_tensor(array, dtype="int32")

test/legacy_test/test_iinfo_and_finfo.py

Lines changed: 43 additions & 0 deletions

@@ -135,6 +135,49 @@ def test_finfo(self):
         self.assertAlmostEqual(xinfo.resolution, 0.01)
         self.assertAlmostEqual(xinfo.smallest_normal, 1.1754943508222875e-38)

+    def test_finfo_alias(self):
+        # dtype and type alias
+        for alias_param in ["dtype", "type"]:
+            for paddle_dtype, np_dtype in [
+                (paddle.float32, np.float32),
+                (paddle.float64, np.float64),
+                ('float32', np.float32),
+                ('float64', np.float64),
+            ]:
+                xinfo = paddle.finfo(**{alias_param: paddle_dtype})
+                xninfo = np.finfo(np_dtype)
+                self.assertEqual(xinfo.dtype, xninfo.dtype)
+                self.assertEqual(xinfo.bits, xninfo.bits)
+                self.assertAlmostEqual(xinfo.max, xninfo.max)
+                self.assertAlmostEqual(xinfo.min, xninfo.min)
+                self.assertAlmostEqual(xinfo.eps, xninfo.eps)
+                self.assertAlmostEqual(xinfo.tiny, xninfo.tiny)
+                self.assertAlmostEqual(xinfo.resolution, xninfo.resolution)
+                if np.lib.NumpyVersion(np.__version__) >= "1.22.0":
+                    self.assertAlmostEqual(
+                        xinfo.smallest_normal, xninfo.smallest_normal
+                    )
+
+            for paddle_dtype, np_dtype in [
+                (paddle.complex64, np.complex64),
+                (paddle.complex128, np.complex128),
+                ('complex64', np.complex64),
+                ('complex128', np.complex128),
+            ]:
+                xinfo = paddle.finfo(**{alias_param: paddle_dtype})
+                xninfo = np.finfo(np_dtype)
+                self.assertEqual(xinfo.dtype, xninfo.dtype)
+                self.assertEqual(xinfo.bits, xninfo.bits)
+                self.assertAlmostEqual(xinfo.max, xninfo.max, places=16)
+                self.assertAlmostEqual(xinfo.min, xninfo.min, places=16)
+                self.assertAlmostEqual(xinfo.eps, xninfo.eps, places=16)
+                self.assertAlmostEqual(xinfo.tiny, xninfo.tiny, places=16)
+                self.assertAlmostEqual(xinfo.resolution, xninfo.resolution)
+                if np.lib.NumpyVersion(np.__version__) >= "1.22.0":
+                    self.assertAlmostEqual(
+                        xinfo.smallest_normal, xninfo.smallest_normal, places=16
+                    )
+

 if __name__ == '__main__':
     unittest.main()

test/legacy_test/test_is_complex.py

Lines changed: 15 additions & 0 deletions

@@ -36,6 +36,21 @@ def test_for_exception(self):
         with self.assertRaises(TypeError):
             paddle.is_complex(np.array([1, 2]))

+    def test_for_alias(self):
+        for alias_param in ["x", "input"]:
+            # test_for_integer
+            x = paddle.arange(10)
+            self.assertFalse(paddle.is_complex(**{alias_param: x}))
+            # test_for_floating_point
+            x = paddle.randn([2, 3])
+            self.assertFalse(paddle.is_complex(**{alias_param: x}))
+            # test_for_complex
+            x = paddle.randn([2, 3]) + 1j * paddle.randn([2, 3])
+            self.assertTrue(paddle.is_complex(**{alias_param: x}))
+            # test_for_exception
+            with self.assertRaises(TypeError):
+                paddle.is_complex(**{alias_param: np.array([1, 2])})
+

 if __name__ == '__main__':
     unittest.main()

test/legacy_test/test_pad_op.py

Lines changed: 58 additions & 0 deletions

@@ -608,6 +608,64 @@ def init_case(self):
         self.pad_value = 0.5


+class TestPadAliasSupport(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+        self.shape = (2, 3)
+        self.paddings = [1, 2, 3, 4]
+        self.value = 0.5
+        self.x = np.random.random(self.shape).astype('float32')
+
+    def test_no_param_name(self):
+        out = paddle.nn.functional.pad(
+            paddle.to_tensor(self.x), self.paddings, value=self.value
+        )
+        expected = np.pad(
+            self.x,
+            [(1, 2), (3, 4)],
+            mode='constant',
+            constant_values=self.value,
+        )
+        np.testing.assert_array_equal(out.numpy(), expected)
+
+    def test_x_param_name(self):
+        out = paddle.nn.functional.pad(
+            x=paddle.to_tensor(self.x), pad=self.paddings, value=self.value
+        )
+        expected = np.pad(
+            self.x,
+            [(1, 2), (3, 4)],
+            mode='constant',
+            constant_values=self.value,
+        )
+        np.testing.assert_array_equal(out.numpy(), expected)
+
+    def test_input_param_name(self):
+        out = paddle.nn.functional.pad(
+            input=paddle.to_tensor(self.x), pad=self.paddings, value=self.value
+        )
+        expected = np.pad(
+            self.x,
+            [(1, 2), (3, 4)],
+            mode='constant',
+            constant_values=self.value,
+        )
+        np.testing.assert_array_equal(out.numpy(), expected)
+
+    def test_both_param_name(self):
+        with self.assertRaises(ValueError) as context:
+            paddle.nn.functional.pad(
+                x=paddle.to_tensor(self.x),
+                input=paddle.to_tensor(self.x),
+                pad=self.paddings,
+                value=self.value,
+            )
+        self.assertIn(
+            "Cannot specify both 'x' and its alias 'input'",
+            str(context.exception),
+        )
+
+
 if __name__ == "__main__":
     # paddle.enable_static()
     unittest.main()
