Skip to content

Commit a2747c7

Browse files
committed
NF: add strided_scalar function
Function to return scalar broadcast out to form full array of given shape.
1 parent 2b55e29 commit a2747c7

File tree

2 files changed

+42
-1
lines changed

2 files changed

+42
-1
lines changed

nibabel/fileslice.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -747,3 +747,26 @@ def fileslice(fileobj, sliceobj, shape, dtype, offset=0, order='C',
747747
bytes = read_segments(fileobj, segments, n_bytes)
748748
sliced = np.ndarray(sliced_shape, dtype, buffer=bytes, order=order)
749749
return sliced[post_slicers]
750+
751+
752+
def strided_scalar(shape, scalar=0.):
753+
""" Return array shape `shape` where all entries point to value `scalar`
754+
755+
Parameters
756+
----------
757+
shape : sequence
758+
Shape of output array.
759+
scalar : scalar
760+
Scalar value with which to fill array.
761+
762+
Returns
763+
-------
764+
strided_arr : array
765+
Array of shape `shape` for which all values == `scalar`, built by
766+
setting all strides of `strided_arr` to 0, so the scalar is broadcast
767+
out to the full array `shape`.
768+
"""
769+
shape = tuple(shape)
770+
scalar = np.array(scalar)
771+
strides = [0] * len(shape)
772+
return np.lib.stride_tricks.as_strided(scalar, shape, strides)

nibabel/tests/test_fileslice.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414
predict_shape, read_segments, _positive_slice,
1515
threshold_heuristic, optimize_slicer, slice2len,
1616
fill_slicer, optimize_read_slicers, slicers2segments,
17-
calc_slicedefs, _simple_fileslice, slice2outax)
17+
calc_slicedefs, _simple_fileslice, slice2outax,
18+
strided_scalar)
1819

1920
from nose.tools import assert_true, assert_false, assert_equal, assert_raises
2021

@@ -627,6 +628,23 @@ def test_predict_shape():
627628
assert_equal(predict_shape((1, slice(None), None), (2, 3)), (3, 1))
628629

629630

631+
def test_strided_scalar():
632+
# Utility to make numpy array of given shape from scalar using striding
633+
for shape, scalar in product(
634+
((2,), (2, 3,), (2, 3, 4)),
635+
(1, 2, np.int16(3))):
636+
expected = np.zeros(shape, dtype=np.array(scalar).dtype) + scalar
637+
observed = strided_scalar(shape, scalar)
638+
assert_array_equal(observed, expected)
639+
assert_equal(observed.shape, shape)
640+
assert_equal(observed.dtype, expected.dtype)
641+
assert_array_equal(observed.strides, 0)
642+
observed[..., 0] = 99
643+
assert_array_equal(observed, expected * 0 + 99)
644+
# Default scalar value is 0
645+
assert_array_equal(strided_scalar((2, 3, 4)), np.zeros((2, 3, 4)))
646+
647+
630648
def _check_bytes(bytes, arr):
631649
barr = np.ndarray(arr.shape, arr.dtype, buffer=bytes)
632650
assert_array_equal(barr, arr)

0 commit comments

Comments
 (0)