@@ -29,7 +29,8 @@ def __init__(self, arr_seq, common_shape, dtype):
29
29
self .lengths = list (arr_seq ._lengths )
30
30
self .next_offset = arr_seq ._get_next_offset ()
31
31
self .bytes_per_buf = arr_seq ._buffer_size * MEGABYTE
32
- self .dtype = dtype
32
+ # Use the passed dtype only if null data array
33
+ self .dtype = dtype if arr_seq ._data .size == 0 else arr_seq ._data .dtype
33
34
if arr_seq .common_shape != () and common_shape != arr_seq .common_shape :
34
35
raise ValueError (
35
36
"All dimensions, except the first one, must match exactly" )
@@ -89,24 +90,7 @@ def __init__(self, iterable=None, buffer_size=4):
89
90
self ._is_view = True
90
91
return
91
92
92
- # If possible try pre-allocating memory.
93
- try :
94
- iter_len = len (iterable )
95
- except TypeError :
96
- pass
97
- else : # We do know the iterable length
98
- if iter_len == 0 :
99
- return
100
- first_element = np .asarray (iterable [0 ])
101
- n_elements = np .sum ([len (iterable [i ])
102
- for i in range (len (iterable ))])
103
- new_shape = (n_elements ,) + first_element .shape [1 :]
104
- self ._data = np .empty (new_shape , dtype = first_element .dtype )
105
-
106
- for e in iterable :
107
- self .append (e , cache_build = True )
108
-
109
- self .finalize_append ()
93
+ self .extend (iterable )
110
94
111
95
@property
112
96
def is_array_sequence (self ):
@@ -237,18 +221,23 @@ def extend(self, elements):
237
221
The shape of the elements to be added must match the one of the data of
238
222
this :class:`ArraySequence` except for the first dimension.
239
223
"""
240
- if not is_array_sequence (elements ):
241
- self .extend (self .__class__ (elements ))
242
- return
243
- if len (elements ) == 0 :
244
- return
245
- self ._build_cache = _BuildCache (self ,
246
- elements .common_shape ,
247
- elements .data .dtype )
248
- self ._resize_data_to (self ._get_next_offset () + elements .nb_elements ,
249
- self ._build_cache )
250
- for element in elements :
251
- self .append (element )
224
+ # If possible try pre-allocating memory.
225
+ try :
226
+ iter_len = len (elements )
227
+ except TypeError :
228
+ pass
229
+ else : # We do know the iterable length
230
+ if iter_len == 0 :
231
+ return
232
+ e0 = np .asarray (elements [0 ])
233
+ n_elements = np .sum ([len (e ) for e in elements ])
234
+ self ._build_cache = _BuildCache (self , e0 .shape [1 :], e0 .dtype )
235
+ self ._resize_data_to (self ._get_next_offset () + n_elements ,
236
+ self ._build_cache )
237
+
238
+ for e in elements :
239
+ self .append (e , cache_build = True )
240
+
252
241
self .finalize_append ()
253
242
254
243
def _extend_using_coroutine (self , buffer_size = 4 ):
0 commit comments