5
5
6
6
import numpy as np
7
7
8
+ from ..deprecated import deprecate_with_version
9
+
8
10
MEGABYTE = 1024 * 1024
9
11
10
12
@@ -53,6 +55,37 @@ def update_seq(self, arr_seq):
53
55
arr_seq ._lengths = np .array (self .lengths )
54
56
55
57
58
def _define_operators(cls):
    """ Decorator which adds support for some Python operators.

    Each generated method simply forwards to ``self._op``, passing the
    operator's name (and, for binary operators, the right-hand value and
    whether the operation is in-place).

    Parameters
    ----------
    cls : type
        Class to extend. Its instances must provide a method
        ``_op(op, value=None, inplace=False)``.

    Returns
    -------
    cls : type
        The class received as input, with the operator methods attached.
    """
    def _wrap(cls, op, inplace=False, unary=False):
        """ Attach to `cls` a method named `op` that delegates to `_op`. """

        def fn_unary_op(self):
            return self._op(op)

        def fn_binary_op(self, value):
            return self._op(op, value, inplace=inplace)

        fn = fn_unary_op if unary else fn_binary_op
        # Make the method introspect as the operator it implements rather
        # than as the closure it was generated from.
        fn.__name__ = op
        fn.__qualname__ = "{}.{}".format(
            getattr(cls, "__qualname__", cls.__name__), op)
        # Borrow the documentation of the equivalent ndarray operator.
        fn.__doc__ = getattr(np.ndarray, op).__doc__
        setattr(cls, op, fn)

    # Binary arithmetic/bitwise operators, each with its in-place variant.
    for op in ["__add__", "__sub__", "__mul__", "__mod__", "__pow__",
               "__floordiv__", "__truediv__", "__lshift__", "__rshift__",
               "__or__", "__and__", "__xor__"]:
        _wrap(cls, op=op, inplace=False)
        _wrap(cls, op="__i{}__".format(op.strip("_")), inplace=True)

    # Comparison operators.
    for op in ["__eq__", "__ne__", "__lt__", "__le__", "__gt__", "__ge__"]:
        _wrap(cls, op)

    # Unary operators.
    for op in ["__neg__", "__abs__", "__invert__"]:
        _wrap(cls, op, unary=True)

    return cls
86
+
87
+
88
+ @_define_operators
56
89
class ArraySequence (object ):
57
90
""" Sequence of ndarrays having variable first dimension sizes.
58
91
@@ -116,9 +149,42 @@ def total_nb_rows(self):
116
149
return np .sum (self ._lengths )
117
150
118
151
@property
@deprecate_with_version("'ArraySequence.data' property is deprecated.\n "
                        "Please use the 'ArraySequence.get_data()' method instead",
                        '3.0', '4.0')
def data(self):
    """ Elements in this array sequence, exposed as a read-only view. """
    # A fresh view shares the buffer of `self._data`; only the view is
    # flagged as non-writeable, `self._data` itself is left untouched.
    readonly = self._data.view()
    readonly.flags.writeable = False
    return readonly
160
+
161
def get_data(self):
    """ Returns a *copy* of the elements in this array sequence.

    The array handed back is the data of ``self.copy()``.

    Notes
    -----
    To modify the data held by this array sequence, use in-place
    mathematical operators (e.g., `seq += ...`) or the assignment
    operator (i.e., `seq[...] = value`).
    """
    duplicate = self.copy()
    return duplicate._data
171
+
172
def _check_shape(self, arrseq):
    """ Check whether this array sequence is compatible with another.

    Three properties are compared, in order: the number of sequences
    (``len(_lengths)``), ``total_nb_rows`` and ``common_shape``.

    Parameters
    ----------
    arrseq : :class:`ArraySequence`
        Array sequence to compare against.

    Returns
    -------
    bool
        Always True; incompatibility is reported by raising.

    Raises
    ------
    ValueError
        On the first property that differs between the two array sequences.
    """
    prefix = "cannot perform operation - array sequences have different"
    comparisons = (
        (" lengths: {} vs. {}.", lambda seq: len(seq._lengths)),
        (" amount of data: {} vs. {}.", lambda seq: seq.total_nb_rows),
        (" common shape: {} vs. {}.", lambda seq: seq.common_shape),
    )

    for suffix, measure in comparisons:
        mine, theirs = measure(self), measure(arrseq)
        if mine != theirs:
            raise ValueError((prefix + suffix).format(mine, theirs))

    return True
122
188
123
189
def _get_next_offset (self ):
124
190
""" Offset in ``self._data`` at which to write next rowelement """
@@ -320,7 +386,7 @@ def __getitem__(self, idx):
320
386
seq ._lengths = self ._lengths [off_idx ]
321
387
return seq
322
388
323
- if isinstance (off_idx , list ) or is_ndarray_of_int_or_bool (off_idx ):
389
+ if isinstance (off_idx , ( list , range ) ) or is_ndarray_of_int_or_bool (off_idx ):
324
390
# Fancy indexing
325
391
seq ._offsets = self ._offsets [off_idx ]
326
392
seq ._lengths = self ._lengths [off_idx ]
@@ -329,6 +395,116 @@ def __getitem__(self, idx):
329
395
raise TypeError ("Index must be either an int, a slice, a list of int"
330
396
" or a ndarray of bool! Not " + str (type (idx )))
331
397
398
def __setitem__(self, idx, elements):
    """ Set sequence(s) through standard or advanced numpy indexing.

    Parameters
    ----------
    idx : int or slice or list or range or ndarray or tuple
        If int, index of the sequence to overwrite.
        If slice, use slicing to select the sequences to overwrite.
        If list or range, indices of the sequences to overwrite.
        If ndarray with dtype int, indices of the sequences to overwrite.
        If ndarray with dtype bool, only overwrite the selected sequences.
        If tuple, its first item selects the sequences (slice, list, range
        or ndarray) and the remaining items index the other dimensions of
        the data of every selected sequence.
    elements : scalar or ndarray or :class:`ArraySequence` or iterable
        Data that will overwrite the selected sequences.
        If `idx` is an int, `elements` is assigned as is to that sequence.
        Otherwise, a number is assigned to every selected sequence, a
        :class:`ArraySequence` is assigned sequence by sequence, and any
        other object is iterated over, one item per selected sequence.

    Raises
    ------
    TypeError
        If the sequence selector has an unsupported type.
    ValueError
        If `elements` is a :class:`ArraySequence` whose number of sequences
        or total number of rows differs from the selection.
    """
    if isinstance(idx, (numbers.Integral, np.integer)):
        start = self._offsets[idx]
        self._data[start:start + self._lengths[idx]] = elements
        return

    if isinstance(idx, tuple):
        off_idx, trailing = idx[0], idx[1:]
    else:
        off_idx, trailing = idx, ()

    # Slices and fancy indices select offsets/lengths the same way.
    if not (isinstance(off_idx, (slice, list, range)) or
            is_ndarray_of_int_or_bool(off_idx)):
        raise TypeError("Index must be either an int, a slice, a list of int"
                        " or a ndarray of bool! Not " + str(type(idx)))

    offsets = self._offsets[off_idx]
    lengths = self._lengths[off_idx]

    def assign(offset, length, value):
        # Index `self._data` directly with the row range plus the trailing
        # indices. Indexing the trailing dimensions first would yield a
        # *copy* whenever they use advanced indexing (list/ndarray), and
        # the assignment would then be silently lost.
        self._data[(slice(offset, offset + length),) + trailing] = value

    # Cheapest type test first; an array sequence is never a number.
    if isinstance(elements, numbers.Number):
        for offset, length in zip(offsets, lengths):
            assign(offset, length, elements)

    elif is_array_sequence(elements):
        if len(lengths) != len(elements):
            msg = "Trying to set {} sequences with {} sequences."
            raise ValueError(msg.format(len(lengths), len(elements)))

        nb_rows = np.sum(lengths)
        if nb_rows != elements.total_nb_rows:
            msg = "Trying to set {} points with {} points."
            raise ValueError(msg.format(nb_rows, elements.total_nb_rows))

        for offset, length, src_offset, src_length in zip(
                offsets, lengths, elements._offsets, elements._lengths):
            assign(offset, length,
                   elements._data[src_offset:src_offset + src_length])

    else:  # Try to iterate over it.
        # NOTE: `zip` stops at the shortest input, so a count mismatch
        # between `elements` and the selection is not reported here.
        for offset, length, element in zip(offsets, lengths, elements):
            assign(offset, length, element)
458
+
459
def _op(self, op, value=None, inplace=False):
    """ Applies some operator to this arraysequence.

    This handles both unary and binary operators with a scalar or another
    array sequence. Operations are performed directly on the underlying
    data, or a copy of it, which depends on the value of `inplace`.

    Parameters
    ----------
    op : str
        Name of the Python operator (e.g., `"__add__"`).
    value : scalar or :class:`ArraySequence`, optional
        If None, the operator is assumed to be unary.
        Otherwise, that value is used in the binary operation.
    inplace: bool, optional
        If False, the operation is done on a copy of this array sequence.
        Otherwise, this array sequence gets modified directly.

    Returns
    -------
    seq : :class:`ArraySequence`
        This array sequence if `inplace` is True, a modified copy otherwise.
        Its data is cast to the dtype produced by the operation. When this
        array sequence holds no sequence at all, nothing is computed and
        the data (and its dtype) is returned unchanged.

    Raises
    ------
    ValueError
        If `value` is an array sequence that fails `_check_shape`.
    """
    seq = self if inplace else self.copy()

    if is_array_sequence(value) and seq._check_shape(value):
        # One right-hand chunk per sequence, paired with ours in order.
        operands = ((value._data[start:start + length],)
                    for start, length in zip(value._offsets, value._lengths))
    else:
        # Dealing with unary (no argument) and binary (shared value) ops.
        args = () if value is None else (value,)
        operands = (args for _ in self._offsets)

    # Looping (rather than pulling the first item with `next`) keeps an
    # array sequence without any sequence from raising StopIteration.
    dtype_matched = False
    for dst_start, dst_len, src_start, src_len, op_args in zip(
            seq._offsets, seq._lengths,
            self._offsets, self._lengths, operands):
        chunk = self._data[src_start:src_start + src_len]
        result = getattr(chunk, op)(*op_args)

        if not dtype_matched:
            # Change seq.dtype to match the operation resulting type.
            seq._data = seq._data.astype(result.dtype)
            dtype_matched = True

        seq._data[dst_start:dst_start + dst_len] = result

    return seq
507
+
332
508
def __iter__ (self ):
333
509
if len (self ._lengths ) != len (self ._offsets ):
334
510
raise ValueError ("ArraySequence object corrupted:"
@@ -371,7 +547,7 @@ def load(cls, filename):
371
547
return seq
372
548
373
549
374
- def create_arraysequences_from_generator (gen , n ):
550
+ def create_arraysequences_from_generator (gen , n , buffer_sizes = None ):
375
551
""" Creates :class:`ArraySequence` objects from a generator yielding tuples
376
552
377
553
Parameters
@@ -381,8 +557,13 @@ def create_arraysequences_from_generator(gen, n):
381
557
array sequences.
382
558
n : int
383
559
Number of :class:`ArraySequences` object to create.
560
+ buffer_sizes : list of float, optional
561
+ Sizes (in Mb) for each ArraySequence's buffer.
384
562
"""
385
- seqs = [ArraySequence () for _ in range (n )]
563
+ if buffer_sizes is None :
564
+ buffer_sizes = [4 ] * n
565
+
566
+ seqs = [ArraySequence (buffer_size = size ) for size in buffer_sizes ]
386
567
for data in gen :
387
568
for i , seq in enumerate (seqs ):
388
569
if data [i ].nbytes > 0 :
0 commit comments