Skip to content

Commit d77dd90

Browse files
authored
[API compatibility] add broadcast_shapes api (#74594)
* add broadcast_shapes api * change judgement
1 parent 20e50fb commit d77dd90

File tree

4 files changed

+124
-0
lines changed

4 files changed

+124
-0
lines changed

python/paddle/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,7 @@
406406
bitwise_right_shift,
407407
bitwise_right_shift_,
408408
broadcast_shape,
409+
broadcast_shapes,
409410
cartesian_prod,
410411
ceil,
411412
clip,
@@ -1024,6 +1025,7 @@
10241025
'DataParallel',
10251026
'argmin',
10261027
'prod',
1028+
'broadcast_shapes',
10271029
'broadcast_shape',
10281030
'conj',
10291031
'neg',

python/paddle/tensor/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,7 @@
269269
bitwise_right_shift,
270270
bitwise_right_shift_,
271271
broadcast_shape,
272+
broadcast_shapes,
272273
cartesian_prod,
273274
ceil,
274275
ceil_,
@@ -638,6 +639,7 @@
638639
'isneginf',
639640
'isposinf',
640641
'isreal',
642+
'broadcast_shapes',
641643
'broadcast_shape',
642644
'conj',
643645
'neg',

python/paddle/tensor/math.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5448,6 +5448,53 @@ def any(
54485448
return out
54495449

54505450

5451+
def broadcast_shapes(*shapes: Sequence[int]) -> list[int]:
5452+
"""
5453+
The function returns the shape of doing operation with broadcasting on tensors of shape list.
5454+
5455+
Note:
5456+
If you want know more about broadcasting, please refer to `Introduction to Tensor`_ .
5457+
5458+
.. _Introduction to Tensor: ../../guides/beginner/tensor_en.html#chapter5-broadcasting-of-tensor
5459+
5460+
Args:
5461+
*shapes (list[int]|tuple[int]): A shape list of multiple tensors.
5462+
5463+
5464+
Returns:
5465+
list[int], the result shape.
5466+
5467+
Examples:
5468+
.. code-block:: python
5469+
5470+
>>> import paddle
5471+
5472+
>>> shape = paddle.broadcast_shapes([2, 1, 3], [1, 3, 1])
5473+
>>> shape
5474+
[2, 3, 3]
5475+
5476+
>>> # shape = paddle.broadcast_shapes([2, 1, 3], [3, 3, 1])
5477+
>>> # ValueError (terminated with error message).
5478+
5479+
>>> shape = paddle.broadcast_shapes([5, 1, 3], [1, 4, 1], [1, 1, 3])
5480+
>>> shape
5481+
[5, 4, 3]
5482+
5483+
>>> # shape = paddle.broadcast_shapes([5, 1, 3], [1, 4, 1], [1, 2, 3])
5484+
>>> # ValueError (terminated with error message).
5485+
5486+
"""
5487+
if len(shapes) == 0:
5488+
return []
5489+
elif len(shapes) == 1:
5490+
return list(shapes[0])
5491+
else:
5492+
current_shape = list(shapes[0])
5493+
for next_shape in shapes[1:]:
5494+
current_shape = broadcast_shape(current_shape, next_shape)
5495+
return current_shape
5496+
5497+
54515498
def broadcast_shape(
54525499
x_shape: Sequence[int], y_shape: Sequence[int]
54535500
) -> list[int]:
Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,73 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
import unittest
16+
17+
import paddle
18+
19+
20+
class TestBroadcastShapes(unittest.TestCase):
21+
def test_result(self):
22+
shape = paddle.broadcast_shapes(
23+
[5, 1, 3, 10],
24+
[5, 4, 1, 1],
25+
[1, 1, 3, 10],
26+
[1, 4, 3, 1],
27+
[1, 4, 1, 10],
28+
)
29+
self.assertEqual(shape, [5, 4, 3, 10])
30+
31+
shape = paddle.broadcast_shapes([-1, 1, 3], [1, 6, 1], [1, 1, 3])
32+
self.assertEqual(shape, [-1, 6, 3])
33+
34+
shape = paddle.broadcast_shapes([8, 3])
35+
36+
self.assertEqual(shape, [8, 3])
37+
38+
shape = paddle.broadcast_shapes([2, 3, 1], [6], [3, 1])
39+
self.assertEqual(shape, [2, 3, 6])
40+
41+
def test_empty(self):
42+
shape = paddle.broadcast_shapes([])
43+
self.assertEqual(shape, [])
44+
45+
shape = paddle.broadcast_shapes([], [2, 3, 4])
46+
self.assertEqual(shape, [2, 3, 4])
47+
48+
shape = paddle.broadcast_shapes([10, 1, 7], [], [1, 6, 1], [1, 1, 7])
49+
self.assertEqual(shape, [10, 6, 7])
50+
51+
def test_complex_case(self):
52+
test_cases = [
53+
([0], [1], [], [0]),
54+
([2, -1], [0], [2, 0]),
55+
([0, 3], [3], [0, 3]),
56+
([0, 1, 3], [0, 1, 0, 3], [1, 0, -1], [0, 0, 0, 3]),
57+
([0, 1, 3], [0, 1, 1, 5, 3], [], [0, 1, 0, 5, 3]),
58+
]
59+
60+
for shape_list in test_cases:
61+
expected = shape_list[-1]
62+
result = paddle.broadcast_shapes(*shape_list[:-1])
63+
self.assertEqual(result, expected)
64+
65+
def test_error(self):
66+
self.assertRaises(
67+
ValueError, paddle.broadcast_shapes, [5, 1, 3], [1, 4, 1], [1, 2, 3]
68+
)
69+
self.assertRaises(ValueError, paddle.broadcast_shapes, [0], [0, 2])
70+
71+
72+
if __name__ == "__main__":
73+
unittest.main()

0 commit comments

Comments
 (0)