Skip to content

Commit 7fabe9b

Browse files
committed
Merge pull request #8 from matthew-brett/refactor-aseq-init
RF: refactor init + extend for arraysequence
2 parents 0a1807f + acfbf39 commit 7fabe9b

File tree

1 file changed

+20
-31
lines changed

1 file changed

+20
-31
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 20 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,8 @@ def __init__(self, arr_seq, common_shape, dtype):
2929
self.lengths = list(arr_seq._lengths)
3030
self.next_offset = arr_seq._get_next_offset()
3131
self.bytes_per_buf = arr_seq._buffer_size * MEGABYTE
32-
self.dtype = dtype
32+
# Use the passed dtype only if null data array
33+
self.dtype = dtype if arr_seq._data.size == 0 else arr_seq._data.dtype
3334
if arr_seq.common_shape != () and common_shape != arr_seq.common_shape:
3435
raise ValueError(
3536
"All dimensions, except the first one, must match exactly")
@@ -89,24 +90,7 @@ def __init__(self, iterable=None, buffer_size=4):
8990
self._is_view = True
9091
return
9192

92-
# If possible try pre-allocating memory.
93-
try:
94-
iter_len = len(iterable)
95-
except TypeError:
96-
pass
97-
else: # We do know the iterable length
98-
if iter_len == 0:
99-
return
100-
first_element = np.asarray(iterable[0])
101-
n_elements = np.sum([len(iterable[i])
102-
for i in range(len(iterable))])
103-
new_shape = (n_elements,) + first_element.shape[1:]
104-
self._data = np.empty(new_shape, dtype=first_element.dtype)
105-
106-
for e in iterable:
107-
self.append(e, cache_build=True)
108-
109-
self.finalize_append()
93+
self.extend(iterable)
11094

11195
@property
11296
def is_array_sequence(self):
@@ -237,18 +221,23 @@ def extend(self, elements):
237221
The shape of the elements to be added must match the one of the data of
238222
this :class:`ArraySequence` except for the first dimension.
239223
"""
240-
if not is_array_sequence(elements):
241-
self.extend(self.__class__(elements))
242-
return
243-
if len(elements) == 0:
244-
return
245-
self._build_cache = _BuildCache(self,
246-
elements.common_shape,
247-
elements.data.dtype)
248-
self._resize_data_to(self._get_next_offset() + elements.nb_elements,
249-
self._build_cache)
250-
for element in elements:
251-
self.append(element)
224+
# If possible try pre-allocating memory.
225+
try:
226+
iter_len = len(elements)
227+
except TypeError:
228+
pass
229+
else: # We do know the iterable length
230+
if iter_len == 0:
231+
return
232+
e0 = np.asarray(elements[0])
233+
n_elements = np.sum([len(e) for e in elements])
234+
self._build_cache = _BuildCache(self, e0.shape[1:], e0.dtype)
235+
self._resize_data_to(self._get_next_offset() + n_elements,
236+
self._build_cache)
237+
238+
for e in elements:
239+
self.append(e, cache_build=True)
240+
252241
self.finalize_append()
253242

254243
def _extend_using_coroutine(self, buffer_size=4):

0 commit comments

Comments
 (0)