Skip to content

Commit e3b4db5

Browse files
committed
RF: refactor to used cached append method
Use caching of append parameters to speed up append for multiple passes.
1 parent d62facc commit e3b4db5

File tree

2 files changed

+148
-66
lines changed

2 files changed

+148
-66
lines changed

nibabel/streamlines/array_sequence.py

Lines changed: 134 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
1+
from __future__ import division
2+
13
import numbers
4+
from operator import mul
5+
from functools import reduce
6+
27
import numpy as np
38

9+
MEGABYTE = 1024 * 1024
10+
411

512
def is_array_sequence(obj):
613
""" Return True if `obj` is an array sequence. """
@@ -16,6 +23,26 @@ def is_ndarray_of_int_or_bool(obj):
1623
np.issubdtype(obj.dtype, np.bool)))
1724

1825

26+
class _BuildCache(object):
27+
def __init__(self, arr_seq, common_shape, dtype):
28+
self.offsets = list(arr_seq._offsets)
29+
self.lengths = list(arr_seq._lengths)
30+
self.next_offset = arr_seq._get_next_offset()
31+
self.bytes_per_buf = arr_seq._buffer_size * MEGABYTE
32+
self.dtype = dtype
33+
if arr_seq.common_shape != () and common_shape != arr_seq.common_shape:
34+
raise ValueError(
35+
"All dimensions, except the first one, must match exactly")
36+
self.common_shape = common_shape
37+
n_in_row = reduce(mul, common_shape, 1)
38+
bytes_per_row = n_in_row * dtype.itemsize
39+
self.rows_per_buf = bytes_per_row / self.bytes_per_buf
40+
41+
def update_seq(self, arr_seq):
42+
arr_seq._offsets = np.array(self.offsets)
43+
arr_seq._lengths = np.array(self.lengths)
44+
45+
1946
class ArraySequence(object):
2047
""" Sequence of ndarrays having variable first dimension sizes.
2148
@@ -48,6 +75,8 @@ def __init__(self, iterable=None, buffer_size=4):
4875
self._data = np.array([])
4976
self._offsets = np.array([], dtype=np.intp)
5077
self._lengths = np.array([], dtype=np.intp)
78+
self._buffer_size = buffer_size
79+
self._build_cache = None
5180

5281
if iterable is None:
5382
return
@@ -60,25 +89,24 @@ def __init__(self, iterable=None, buffer_size=4):
6089
self._is_view = True
6190
return
6291

92+
# If possible try pre-allocating memory.
6393
try:
64-
# If possible try pre-allocating memory.
65-
if len(iterable) > 0:
66-
first_element = np.asarray(iterable[0])
67-
n_elements = np.sum([len(iterable[i])
68-
for i in range(len(iterable))])
69-
new_shape = (n_elements,) + first_element.shape[1:]
70-
self._data = np.empty(new_shape, dtype=first_element.dtype)
94+
iter_len = len(iterable)
7195
except TypeError:
7296
pass
73-
74-
# Initialize the `ArraySequence` object from iterable's item.
75-
coroutine = self._extend_using_coroutine()
76-
coroutine.send(None) # Run until the first yield.
97+
else: # We do know the iterable length
98+
if iter_len == 0:
99+
return
100+
first_element = np.asarray(iterable[0])
101+
n_elements = np.sum([len(iterable[i])
102+
for i in range(len(iterable))])
103+
new_shape = (n_elements,) + first_element.shape[1:]
104+
self._data = np.empty(new_shape, dtype=first_element.dtype)
77105

78106
for e in iterable:
79-
coroutine.send(e)
107+
self.append(e, cache_build=True)
80108

81-
coroutine.close() # Terminate coroutine.
109+
self.finalize_append()
82110

83111
@property
84112
def is_array_sequence(self):
@@ -92,21 +120,40 @@ def common_shape(self):
92120
@property
93121
def nb_elements(self):
94122
""" Total number of elements in this array sequence. """
95-
return self._data.shape[0]
123+
return np.sum(self._lengths)
96124

97125
@property
98126
def data(self):
99127
""" Elements in this array sequence. """
100128
return self._data
101129

102-
def append(self, element):
130+
def _get_next_offset(self):
131+
""" Offset in ``self._data`` at which to write next element """
132+
if len(self._offsets) == 0:
133+
return 0
134+
imax = np.argmax(self._offsets)
135+
return self._offsets[imax] + self._lengths[imax]
136+
137+
def append(self, element, cache_build=False):
103138
""" Appends `element` to this array sequence.
104139
140+
Append can be a lot faster if it knows that it is appending several
141+
elements instead of a single element. In that case it can cache the
142+
parameters it uses between append operations, in a "build cache". To
143+
tell append to do this, use ``cache_build=True``. If you use
144+
``cache_build=True``, you need to finalize the append operations with
145+
:method:`finalize_append`.
146+
105147
Parameters
106148
----------
107149
element : ndarray
108150
Element to append. The shape must match already inserted elements
109151
shape except for the first dimension.
152+
cache_build : {False, True}
153+
Whether to save the build cache from this append routine. If True,
154+
append can assume it is the only player updating `self`, and the
155+
caller must finalize `self` after all append operations, with
156+
``self.finalize_append()``.
110157
111158
Returns
112159
-------
@@ -118,17 +165,56 @@ def append(self, element):
118165
`ArraySequence.extend`.
119166
"""
120167
element = np.asarray(element)
168+
if element.size == 0:
169+
return
170+
el_shape = element.shape
171+
n_items, common_shape = el_shape[0], el_shape[1:]
172+
build_cache = self._build_cache
173+
in_cached_build = build_cache is not None
174+
if not in_cached_build: # One shot append, not part of sequence
175+
build_cache = _BuildCache(self, common_shape, element.dtype)
176+
next_offset = build_cache.next_offset
177+
req_rows = next_offset + n_items
178+
if self._data.shape[0] < req_rows:
179+
self._resize_data_to(req_rows, build_cache)
180+
self._data[next_offset:req_rows] = element
181+
build_cache.offsets.append(next_offset)
182+
build_cache.lengths.append(n_items)
183+
build_cache.next_offset = req_rows
184+
if in_cached_build:
185+
return
186+
if cache_build:
187+
self._build_cache = build_cache
188+
else:
189+
build_cache.update_seq(self)
121190

122-
if self.common_shape != () and element.shape[1:] != self.common_shape:
123-
msg = "All dimensions, except the first one, must match exactly"
124-
raise ValueError(msg)
191+
def finalize_append(self):
192+
""" Finalize process of appending several elements to `self`
125193
126-
next_offset = self._data.shape[0]
127-
size = (self._data.shape[0] + element.shape[0],) + element.shape[1:]
128-
self._data.resize(size)
129-
self._data[next_offset:] = element
130-
self._offsets = np.r_[self._offsets, next_offset]
131-
self._lengths = np.r_[self._lengths, element.shape[0]]
194+
:method:`append` can be a lot faster if it knows that it is appending
195+
several elements instead of a single element. To tell the append
196+
method this is the case, use ``cache_build=True``. This method
197+
finalizes the series of append operations after a call to
198+
:method:`append` with ``cache_build=True``.
199+
"""
200+
if self._build_cache is None:
201+
return
202+
self._build_cache.update_seq(self)
203+
self._build_cache = None
204+
205+
def _resize_data_to(self, n_rows, build_cache):
206+
""" Resize data array if required """
207+
# Calculate new data shape, rounding up to nearest buffer size
208+
n_bufs = np.ceil(n_rows / build_cache.rows_per_buf)
209+
extended_n_rows = int(n_bufs * build_cache.rows_per_buf)
210+
new_shape = (extended_n_rows,) + build_cache.common_shape
211+
if self._data.size == 0:
212+
self._data = np.empty(new_shape, dtype=build_cache.dtype)
213+
else:
214+
self._data.resize(new_shape)
215+
216+
def shrink_data(self):
217+
self._data.resize((self._get_next_offset(),) + self.common_shape)
132218

133219
def extend(self, elements):
134220
""" Appends all `elements` to this array sequence.
@@ -154,28 +240,16 @@ def extend(self, elements):
154240
if not is_array_sequence(elements):
155241
self.extend(self.__class__(elements))
156242
return
157-
158243
if len(elements) == 0:
159244
return
160-
161-
if (self.common_shape != () and
162-
elements.common_shape != self.common_shape):
163-
msg = "All dimensions, except the first one, must match exactly"
164-
raise ValueError(msg)
165-
166-
next_offset = self._data.shape[0]
167-
self._data.resize((self._data.shape[0] + sum(elements._lengths),
168-
elements._data.shape[1]))
169-
170-
offsets = []
171-
for offset, length in zip(elements._offsets, elements._lengths):
172-
offsets.append(next_offset)
173-
chunk = elements._data[offset:offset + length]
174-
self._data[next_offset:next_offset + length] = chunk
175-
next_offset += length
176-
177-
self._lengths = np.r_[self._lengths, elements._lengths]
178-
self._offsets = np.r_[self._offsets, offsets]
245+
self._build_cache = _BuildCache(self,
246+
elements.common_shape,
247+
elements.data.dtype)
248+
self._resize_data_to(self._get_next_offset() + elements.nb_elements,
249+
self._build_cache)
250+
for element in elements:
251+
self.append(element)
252+
self.finalize_append()
179253

180254
def _extend_using_coroutine(self, buffer_size=4):
181255
""" Creates a coroutine allowing to append elements.
@@ -204,7 +278,7 @@ def _extend_using_coroutine(self, buffer_size=4):
204278
offsets = []
205279
lengths = []
206280

207-
offset = 0 if len(self) == 0 else self._offsets[-1] + self._lengths[-1]
281+
offset = self._get_next_offset()
208282
try:
209283
first_element = True
210284
while True:
@@ -293,20 +367,24 @@ def __getitem__(self, idx):
293367
start = self._offsets[idx]
294368
return self._data[start:start + self._lengths[idx]]
295369

296-
elif isinstance(idx, (slice, list)) or is_ndarray_of_int_or_bool(idx):
297-
seq = self.__class__()
370+
seq = self.__class__()
371+
seq._is_view = True
372+
if isinstance(idx, tuple):
373+
off_idx = idx[0]
374+
seq._data = self._data.__getitem__((slice(None),) + idx[1:])
375+
else:
376+
off_idx = idx
298377
seq._data = self._data
299-
seq._offsets = self._offsets[idx]
300-
seq._lengths = self._lengths[idx]
301-
seq._is_view = True
378+
379+
if isinstance(off_idx, slice): # Standard list slicing
380+
seq._offsets = self._offsets[off_idx]
381+
seq._lengths = self._lengths[off_idx]
302382
return seq
303383

304-
elif isinstance(idx, tuple):
305-
seq = self.__class__()
306-
seq._data = self._data.__getitem__((slice(None),) + idx[1:])
307-
seq._offsets = self._offsets[idx[0]]
308-
seq._lengths = self._lengths[idx[0]]
309-
seq._is_view = True
384+
if isinstance(off_idx, list) or is_ndarray_of_int_or_bool(off_idx):
385+
# Fancy indexing
386+
seq._offsets = self._offsets[off_idx]
387+
seq._lengths = self._lengths[off_idx]
310388
return seq
311389

312390
raise TypeError("Index must be either an int, a slice, a list of int"

nibabel/streamlines/tests/test_array_sequence.py

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ def check_arr_seq(seq, arrays):
5252
# The only thing we can check is the _lengths.
5353
assert_array_equal(sorted(seq._lengths), sorted(lengths))
5454
else:
55+
seq.shrink_data()
5556
assert_equal(seq._data.shape[0], sum(lengths))
5657
assert_array_equal(seq._data, np.concatenate(arrays, axis=0))
5758
assert_array_equal(seq._offsets, np.r_[0, np.cumsum(lengths)[:-1]])
@@ -113,20 +114,23 @@ def test_arraysequence_iter(self):
113114
assert_raises(ValueError, list, seq)
114115

115116
def test_arraysequence_copy(self):
116-
seq = SEQ_DATA['seq'].copy()
117-
assert_array_equal(seq._data, SEQ_DATA['seq']._data)
118-
assert_true(seq._data is not SEQ_DATA['seq']._data)
119-
assert_array_equal(seq._offsets, SEQ_DATA['seq']._offsets)
120-
assert_true(seq._offsets is not SEQ_DATA['seq']._offsets)
121-
assert_array_equal(seq._lengths, SEQ_DATA['seq']._lengths)
122-
assert_true(seq._lengths is not SEQ_DATA['seq']._lengths)
123-
assert_equal(seq.common_shape, SEQ_DATA['seq'].common_shape)
117+
orig = SEQ_DATA['seq']
118+
seq = orig.copy()
119+
n_rows = seq.nb_elements
120+
assert_equal(n_rows, orig.nb_elements)
121+
assert_array_equal(seq._data, orig._data[:n_rows])
122+
assert_true(seq._data is not orig._data)
123+
assert_array_equal(seq._offsets, orig._offsets)
124+
assert_true(seq._offsets is not orig._offsets)
125+
assert_array_equal(seq._lengths, orig._lengths)
126+
assert_true(seq._lengths is not orig._lengths)
127+
assert_equal(seq.common_shape, orig.common_shape)
124128

125129
# Taking a copy of an `ArraySequence` generated by slicing.
126130
# Only keep needed data.
127-
seq = SEQ_DATA['seq'][::2].copy()
131+
seq = orig[::2].copy()
128132
check_arr_seq(seq, SEQ_DATA['data'][::2])
129-
assert_true(seq._data is not SEQ_DATA['seq']._data)
133+
assert_true(seq._data is not orig._data)
130134

131135
def test_arraysequence_append(self):
132136
element = generate_data(nb_arrays=1,

0 commit comments

Comments
 (0)