@@ -53,6 +53,35 @@ def update_seq(self, arr_seq):
53
53
arr_seq ._lengths = np .array (self .lengths )
54
54
55
55
56
+ def _define_operators (cls ):
57
+ """ Decorator which adds support for some Python operators. """
58
+ def _wrap (cls , op , name = None , inplace = False , unary = False ):
59
+ name = name or op
60
+ if unary :
61
+ setattr (cls , name , lambda self : self ._op (op ))
62
+ else :
63
+ setattr (cls , name ,
64
+ lambda self , value : self ._op (op , value , inplace = inplace ))
65
+
66
+ for op in ["__iadd__" , "__isub__" , "__imul__" , "__idiv__" ,
67
+ "__ifloordiv__" , "__itruediv__" , "__ior__" ]:
68
+ _wrap (cls , op , inplace = True )
69
+
70
+ for op in ["__add__" , "__sub__" , "__mul__" , "__div__" ,
71
+ "__floordiv__" , "__truediv__" , "__or__" ]:
72
+ op_ = "__i{}__" .format (op .strip ("_" ))
73
+ _wrap (cls , op_ , name = op )
74
+
75
+ for op in ["__eq__" , "__ne__" , "__lt__" , "__le__" , "__gt__" , "__ge__" ]:
76
+ _wrap (cls , op )
77
+
78
+ for op in ["__neg__" ]:
79
+ _wrap (cls , op , unary = True )
80
+
81
+ return cls
82
+
83
+
84
+ @_define_operators
56
85
class ArraySequence (object ):
57
86
""" Sequence of ndarrays having variable first dimension sizes.
58
87
@@ -120,6 +149,23 @@ def data(self):
120
149
""" Elements in this array sequence. """
121
150
return self ._data
122
151
152
+ def _check_shape (self , arrseq ):
153
+ """ Check whether this array sequence is compatible with another. """
154
+ msg = "cannot perform operation - array sequences have different"
155
+ if len (self ._lengths ) != len (arrseq ._lengths ):
156
+ msg += " lengths: {} vs. {}."
157
+ raise ValueError (msg .format (len (self ._lengths ), len (arrseq ._lengths )))
158
+
159
+ if self .total_nb_rows != arrseq .total_nb_rows :
160
+ msg += " amount of data: {} vs. {}."
161
+ raise ValueError (msg .format (self .total_nb_rows , arrseq .total_nb_rows ))
162
+
163
+ if self .common_shape != arrseq .common_shape :
164
+ msg += " common shape: {} vs. {}."
165
+ raise ValueError (msg .format (self .common_shape , arrseq .common_shape ))
166
+
167
+ return True
168
+
123
169
def _get_next_offset (self ):
124
170
""" Offset in ``self._data`` at which to write next rowelement """
125
171
if len (self ._offsets ) == 0 :
@@ -377,6 +423,37 @@ def __setitem__(self, idx, elements):
377
423
for o1 , l1 , o2 , l2 in zip (offsets , lengths , elements ._offsets , elements ._lengths ):
378
424
data [o1 :o1 + l1 ] = elements ._data [o2 :o2 + l2 ]
379
425
426
+ def _op (self , op , value = None , inplace = False ):
427
+ """ Applies some operator to this arraysequence.
428
+
429
+ This handles both unary and binary operators with a scalar or another
430
+ array sequence. Operations are performed directly on the underlying
431
+ data, or a copy of it, which depends on the value of `inplace`.
432
+
433
+ Parameters
434
+ ----------
435
+ op : str
436
+ Name of the Python operator (e.g., `"__add__"`).
437
+ value : scalar or :class:`ArraySequence`, optional
438
+ If None, the operator is assumed to be unary.
439
+ Otherwise, that value is used in the binary operation.
440
+ inplace: bool, optional
441
+ If False, the operation is done on a copy of this array sequence.
442
+ Otherwise, this array sequence gets modified directly.
443
+ """
444
+ seq = self if inplace else self .copy ()
445
+
446
+ if is_array_sequence (value ) and seq ._check_shape (value ):
447
+ for o1 , l1 , o2 , l2 in zip (seq ._offsets , seq ._lengths , value ._offsets , value ._lengths ):
448
+ seq ._data [o1 :o1 + l1 ] = getattr (seq ._data [o1 :o1 + l1 ], op )(value ._data [o2 :o2 + l2 ])
449
+
450
+ else :
451
+ args = [] if value is None else [value ] # Dealing with unary and binary ops.
452
+ for o1 , l1 in zip (seq ._offsets , seq ._lengths ):
453
+ seq ._data [o1 :o1 + l1 ] = getattr (seq ._data [o1 :o1 + l1 ], op )(* args )
454
+
455
+ return seq
456
+
380
457
def __iter__ (self ):
381
458
if len (self ._lengths ) != len (self ._offsets ):
382
459
raise ValueError ("ArraySequence object corrupted:"
0 commit comments