@@ -274,6 +274,56 @@ def test_arraysequence_getitem(self):
274
274
check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
275
275
check_arr_seq (seq_view , [d [:, 2 ] for d in SEQ_DATA ['data' ][::- 2 ]])
276
276
277
+ def test_arraysequence_setitem (self ):
278
+ # Set one item
279
+ seq = SEQ_DATA ['seq' ] * 0
280
+ for i , e in enumerate (SEQ_DATA ['seq' ]):
281
+ seq [i ] = e
282
+
283
+ check_arr_seq (seq , SEQ_DATA ['seq' ])
284
+
285
+ # Setitem with a scalar.
286
+ seq = SEQ_DATA ['seq' ].copy ()
287
+ seq [:] = 0
288
+ assert_true (seq ._data .sum () == 0 )
289
+
290
+ # Setitem with a list of ndarray.
291
+ seq = SEQ_DATA ['seq' ] * 0
292
+ seq [:] = SEQ_DATA ['data' ]
293
+ check_arr_seq (seq , SEQ_DATA ['data' ])
294
+
295
+ # Setitem using tuple indexing.
296
+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
297
+ seq [:, 0 ] = 0
298
+ assert_true (seq ._data [:, 0 ].sum () == 0 )
299
+
300
+ # Setitem using tuple indexing.
301
+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
302
+ seq [range (len (seq ))] = 0
303
+ assert_true (seq ._data .sum () == 0 )
304
+
305
+ # Setitem of a slice using another slice.
306
+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
307
+ seq [0 :4 ] = seq [5 :9 ]
308
+ check_arr_seq (seq [0 :4 ], seq [5 :9 ])
309
+
310
+ # Setitem between array sequences with different number of sequences.
311
+ seq = ArraySequence (np .arange (900 ).reshape ((50 ,6 ,3 )))
312
+ assert_raises (ValueError , seq .__setitem__ , slice (0 , 4 ), seq [5 :10 ])
313
+
314
+ # Setitem between array sequences with different amount of points.
315
+ seq1 = ArraySequence (np .arange (10 ).reshape (5 , 2 ))
316
+ seq2 = ArraySequence (np .arange (15 ).reshape (5 , 3 ))
317
+ assert_raises (ValueError , seq1 .__setitem__ , slice (0 , 5 ), seq2 )
318
+
319
+ # Setitem between array sequences with different common shape.
320
+ seq1 = ArraySequence (np .arange (12 ).reshape (2 , 2 , 3 ))
321
+ seq2 = ArraySequence (np .arange (8 ).reshape (2 , 2 , 2 ))
322
+ assert_raises (ValueError , seq1 .__setitem__ , slice (0 , 2 ), seq2 )
323
+
324
+ # Invalid index.
325
+ assert_raises (TypeError , seq .__setitem__ , object (), None )
326
+
277
327
def test_arraysequence_operators (self ):
278
328
# Disable division per zero warnings.
279
329
flags = np .seterr (divide = 'ignore' , invalid = 'ignore' )
@@ -375,61 +425,6 @@ def _test_binary(op, arrseq, scalars, seqs, inplace=False):
375
425
# Restore flags.
376
426
np .seterr (** flags )
377
427
378
-
379
- def test_arraysequence_setitem (self ):
380
- # Set one item
381
- seq = SEQ_DATA ['seq' ] * 0
382
- for i , e in enumerate (SEQ_DATA ['seq' ]):
383
- seq [i ] = e
384
-
385
- check_arr_seq (seq , SEQ_DATA ['seq' ])
386
-
387
- # Get all items using indexing (creates a view).
388
- indices = list (range (len (SEQ_DATA ['seq' ])))
389
- seq_view = SEQ_DATA ['seq' ][indices ]
390
- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
391
- # We took all elements so the view should match the original.
392
- check_arr_seq (seq_view , SEQ_DATA ['seq' ])
393
-
394
- # Get multiple items using ndarray of dtype integer.
395
- for dtype in [np .int8 , np .int16 , np .int32 , np .int64 ]:
396
- seq_view = SEQ_DATA ['seq' ][np .array (indices , dtype = dtype )]
397
- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
398
- # We took all elements so the view should match the original.
399
- check_arr_seq (seq_view , SEQ_DATA ['seq' ])
400
-
401
- # Get multiple items out of order (creates a view).
402
- SEQ_DATA ['rng' ].shuffle (indices )
403
- seq_view = SEQ_DATA ['seq' ][indices ]
404
- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
405
- check_arr_seq (seq_view , [SEQ_DATA ['data' ][i ] for i in indices ])
406
-
407
- # Get slice (this will create a view).
408
- seq_view = SEQ_DATA ['seq' ][::2 ]
409
- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
410
- check_arr_seq (seq_view , SEQ_DATA ['data' ][::2 ])
411
-
412
- # Use advanced indexing with ndarray of data type bool.
413
- selection = np .array ([False , True , True , False , True ])
414
- seq_view = SEQ_DATA ['seq' ][selection ]
415
- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
416
- check_arr_seq (seq_view ,
417
- [SEQ_DATA ['data' ][i ]
418
- for i , keep in enumerate (selection ) if keep ])
419
-
420
- # Test invalid indexing
421
- assert_raises (TypeError , SEQ_DATA ['seq' ].__getitem__ , 'abc' )
422
-
423
- # Get specific columns.
424
- seq_view = SEQ_DATA ['seq' ][:, 2 ]
425
- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
426
- check_arr_seq (seq_view , [d [:, 2 ] for d in SEQ_DATA ['data' ]])
427
-
428
- # Combining multiple slicing and indexing operations.
429
- seq_view = SEQ_DATA ['seq' ][::- 2 ][:, 2 ]
430
- check_arr_seq_view (seq_view , SEQ_DATA ['seq' ])
431
- check_arr_seq (seq_view , [d [:, 2 ] for d in SEQ_DATA ['data' ][::- 2 ]])
432
-
433
428
def test_arraysequence_repr (self ):
434
429
# Test that calling repr on a ArraySequence object is not falling.
435
430
repr (SEQ_DATA ['seq' ])
0 commit comments