Skip to content

Commit e327ca9

Browse files
zhengshengning authored and maxiaolong001 committed
[API compatibility] add Alias : paddle.unique_consecutive, paddle.embedding, paddle.ones_like, paddle.repeat_interleave, paddle.var, paddle.take_along_axis (PaddlePaddle#74490)
* add alias : unique_consecutive, embedding
* add alias : ones_like, repeat_interleave, var
1 parent 6344bfb commit e327ca9

File tree

8 files changed

+70
-0
lines changed

8 files changed

+70
-0
lines changed

python/paddle/nn/functional/input.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import paddle
1919
from paddle import _C_ops
20+
from paddle.utils.decorator_utils import ParamAliasDecorator
2021

2122
from ...base.data_feeder import check_variable_and_dtype
2223
from ...base.layer_helper import LayerHelper
@@ -161,6 +162,7 @@ def embedding_renorm_(
161162
return weight
162163

163164

165+
@ParamAliasDecorator({"x": ["input"]})
164166
def embedding(
165167
x: Tensor,
166168
weight: Tensor,

python/paddle/tensor/creation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1327,6 +1327,7 @@ def ones(
13271327
)
13281328

13291329

1330+
@ParamAliasDecorator({"x": ["input"]})
13301331
def ones_like(
13311332
x: paddle.Tensor,
13321333
dtype: DTypeLike | None = None,

python/paddle/tensor/manipulation.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3463,6 +3463,7 @@ def squeeze_(
34633463
return _C_ops.squeeze_(input, axes)
34643464

34653465

3466+
@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
34663467
def unique_consecutive(
34673468
x: Tensor,
34683469
return_inverse: bool = False,
@@ -6288,6 +6289,7 @@ def as_real(x: Tensor, name: str | None = None) -> Tensor:
62886289
return out
62896290

62906291

6292+
@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
62916293
def repeat_interleave(
62926294
x: Tensor,
62936295
repeats: int | Tensor,
@@ -6690,6 +6692,7 @@ def infer_broadcast_shape(
66906692
return broadcast_shape
66916693

66926694

6695+
@ParamAliasDecorator({"arr": ["input"], "axis": ["dim"]})
66936696
def take_along_axis(
66946697
arr: Tensor, indices: Tensor, axis: int, broadcast: bool = True
66956698
) -> Tensor:

python/paddle/tensor/stat.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
in_dynamic_mode,
2626
in_dynamic_or_pir_mode,
2727
)
28+
from paddle.utils.decorator_utils import ParamAliasDecorator
2829

2930
from ..base.data_feeder import check_type, check_variable_and_dtype
3031
from ..common_ops_import import Variable
@@ -149,6 +150,7 @@ def mean(
149150
return out
150151

151152

153+
@ParamAliasDecorator({"x": ["input"], "axis": ["dim"]})
152154
def var(
153155
x: Tensor,
154156
axis: int | Sequence[int] | None = None,

test/auto_parallel/semi_auto_parallel_for_embedding.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ def test_body(self, x_shape, w_shape, x_placements, w_placements):
5757
dist_out.backward()
5858
self.check_tensor_eq(w.grad, dist_w.grad)
5959

60+
out = paddle.nn.functional.embedding(input=x, weight=w)
61+
dist_out = paddle.nn.functional.embedding(input=dist_x, weight=dist_w)
62+
self.check_tensor_eq(out, dist_out)
63+
64+
out.backward()
65+
dist_out.backward()
66+
self.check_tensor_eq(w.grad, dist_w.grad)
67+
6068
return dist_out, dist_w.grad
6169

6270
def test_non_shard(self):

test/legacy_test/test_unique_consecutive_op.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -232,6 +232,13 @@ def test_dygraph(self):
232232
x = paddle.to_tensor(input_x)
233233
result = paddle.unique_consecutive(x)
234234

235+
def test_dygraph_alias(self):
236+
for place in self.places:
237+
with base.dygraph.guard(place):
238+
input_x = np.random.randint(20, size=100).astype("float64")
239+
x = paddle.to_tensor(input_x)
240+
result = paddle.unique_consecutive(input=x)
241+
235242

236243
class TestUniqueConsecutiveCase2API(unittest.TestCase):
237244
def setUp(self):
@@ -299,9 +306,32 @@ def check_static_result(self, place):
299306
fetch_list=[result],
300307
)
301308

309+
def check_static_result_alias(self, place):
310+
with paddle.static.program_guard(
311+
paddle.static.Program(), paddle.static.Program()
312+
):
313+
paddle.enable_static()
314+
input_x = paddle.static.data(
315+
name="input_x",
316+
shape=[
317+
100,
318+
],
319+
dtype="float32",
320+
)
321+
result, inverse, counts = paddle.unique_consecutive(
322+
input=input_x, return_inverse=True, return_counts=True, axis=-1
323+
)
324+
x_np = np.random.randint(20, size=100).astype("float32")
325+
exe = base.Executor(place)
326+
fetches = exe.run(
327+
feed={"input_x": x_np},
328+
fetch_list=[result],
329+
)
330+
302331
def test_static(self):
303332
for place in self.places:
304333
self.check_static_result(place=place)
334+
self.check_static_result_alias(place=place)
305335

306336
def test_dygraph(self):
307337
for place in self.places:

test/legacy_test/test_zero_dim_no_backward_api.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,18 @@ def test_embedding(self):
175175
for i in range(len(res)):
176176
self.assertEqual(emb.numpy()[i], res[i])
177177

178+
def test_embedding_alias(self):
179+
ids = paddle.full(shape=[], fill_value=1, dtype='int64')
180+
w0 = paddle.arange(3, 9).reshape((3, 2)).astype(paddle.float32)
181+
w = paddle.to_tensor(w0, stop_gradient=False)
182+
emb = paddle.nn.functional.embedding(
183+
input=ids, weight=w, sparse=True, name="embedding"
184+
)
185+
self.assertEqual(emb.shape, [2])
186+
res = [5.0, 6.0]
187+
for i in range(len(res)):
188+
self.assertEqual(emb.numpy()[i], res[i])
189+
178190
def test_one_hot_label(self):
179191
label = paddle.full(shape=[], fill_value=2, dtype='int64')
180192
one_hot_label = paddle.nn.functional.one_hot(label, num_classes=4)

test/xpu/test_zero_dim_tensor_xpu.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2673,6 +2673,18 @@ def test_embedding(self):
26732673
for i in range(len(res)):
26742674
self.assertEqual(emb.numpy()[i], res[i])
26752675

2676+
def test_embedding_alias(self):
2677+
ids = paddle.full(shape=[], fill_value=1, dtype='int64')
2678+
w0 = paddle.arange(3, 9).reshape((3, 2)).astype(paddle.float32)
2679+
w = paddle.to_tensor(w0, stop_gradient=False)
2680+
emb = paddle.nn.functional.embedding(
2681+
input=ids, weight=w, sparse=True, name="embedding"
2682+
)
2683+
self.assertEqual(emb.shape, [2])
2684+
res = [5.0, 6.0]
2685+
for i in range(len(res)):
2686+
self.assertEqual(emb.numpy()[i], res[i])
2687+
26762688
def test_one_hot_label(self):
26772689
label = paddle.full(shape=[], fill_value=2, dtype='int64')
26782690
one_hot_label = paddle.nn.functional.one_hot(label, num_classes=4)

0 commit comments

Comments
 (0)