Skip to content

Commit a03e0ef

Browse files
authored
ENH: Improve performance of np.broadcast_arrays and np.broadcast_shapes (numpy#26160)
* ENH: Improve performance of np.broadcast_arrays * modify tests * lint * whitespace * lint * improve performance of broadcast_shapes
1 parent e2fb336 commit a03e0ef

File tree

2 files changed

+17
-16
lines changed

2 files changed

+17
-16
lines changed

numpy/lib/_stride_tricks_impl.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ def broadcast_shapes(*args):
478478
>>> np.broadcast_shapes((6, 7), (5, 6, 1), (7,), (5, 1, 7))
479479
(5, 6, 7)
480480
"""
481-
arrays = [np.empty(x, dtype=[]) for x in args]
481+
arrays = [np.empty(x, dtype=bool) for x in args]
482482
return _broadcast_shape(*arrays)
483483

484484

@@ -546,13 +546,12 @@ def broadcast_arrays(*args, subok=False):
546546
# return np.nditer(args, flags=['multi_index', 'zerosize_ok'],
547547
# order='C').itviews
548548

549-
args = tuple(np.array(_m, copy=None, subok=subok) for _m in args)
549+
args = [np.array(_m, copy=None, subok=subok) for _m in args]
550550

551551
shape = _broadcast_shape(*args)
552552

553-
if all(array.shape == shape for array in args):
554-
# Common case where nothing needs to be broadcasted.
555-
return args
553+
result = [array if array.shape == shape
554+
else _broadcast_to(array, shape, subok=subok, readonly=False)
555+
for array in args]
556+
return tuple(result)
556557

557-
return tuple(_broadcast_to(array, shape, subok=subok, readonly=False)
558-
for array in args)

numpy/lib/tests/test_stride_tricks.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,7 +341,7 @@ def test_broadcast_shapes_raises():
341341
[(2, 3), (2,)],
342342
[(3,), (3,), (4,)],
343343
[(1, 3, 4), (2, 3, 3)],
344-
[(1, 2), (3,1), (3,2), (10, 5)],
344+
[(1, 2), (3, 1), (3, 2), (10, 5)],
345345
[2, (2, 3)],
346346
]
347347
for input_shapes in data:
@@ -578,11 +578,12 @@ def test_writeable():
578578

579579
# but the result of broadcast_arrays needs to be writeable, to
580580
# preserve backwards compatibility
581-
for is_broadcast, results in [(False, broadcast_arrays(original,)),
582-
(True, broadcast_arrays(0, original))]:
583-
for result in results:
581+
test_cases = [((False,), broadcast_arrays(original,)),
582+
((True, False), broadcast_arrays(0, original))]
583+
for is_broadcast, results in test_cases:
584+
for array_is_broadcast, result in zip(is_broadcast, results):
584585
# This will change to False in a future version
585-
if is_broadcast:
586+
if array_is_broadcast:
586587
with assert_warns(FutureWarning):
587588
assert_equal(result.flags.writeable, True)
588589
with assert_warns(DeprecationWarning):
@@ -623,11 +624,12 @@ def test_writeable_memoryview():
623624
# See gh-13929.
624625
original = np.array([1, 2, 3])
625626

626-
for is_broadcast, results in [(False, broadcast_arrays(original,)),
627-
(True, broadcast_arrays(0, original))]:
628-
for result in results:
627+
test_cases = [((False, ), broadcast_arrays(original,)),
628+
((True, False), broadcast_arrays(0, original))]
629+
for is_broadcast, results in test_cases:
630+
for array_is_broadcast, result in zip(is_broadcast, results):
629631
# This will change to False in a future version
630-
if is_broadcast:
632+
if array_is_broadcast:
631633
# memoryview(result, writable=True) will give warning but cannot
632634
# be tested using the python API.
633635
assert memoryview(result).readonly

0 commit comments

Comments
 (0)