9
9
10
10
from larray .tests .common import assert_array_nan_equal , inputpath , tmp_path , meta
11
11
from larray import (Session , Axis , LArray , Group , isnan , zeros_like , ndtest , ones_like ,
12
- local_arrays , global_arrays , arrays )
12
+ local_arrays , global_arrays , arrays , nan )
13
13
from larray .util .misc import pickle
14
14
15
15
try :
@@ -388,81 +388,105 @@ def test_element_equals(session):
388
388
389
389
390
390
def test_eq (session ):
391
- sess = session .filter (kind = LArray )
392
- expected = Session ([('e' , e ), ('f' , f ), ('g' , g )])
393
- assert all ([array .all () for array in (sess == expected ).values ()])
391
+ sess = session .filter (kind = (Axis , Group , LArray ))
392
+ expected = Session ([('b' , b ), ('b12' , b12 ), ('a' , a ), ('a01' , a01 ),
393
+ ('e' , e ), ('g' , g ), ('f' , f )])
394
+ assert all ([item .all () if isinstance (item , LArray ) else item
395
+ for item in (sess == expected ).values ()])
394
396
395
- other = Session ([('e' , e ), ('f' , f )])
397
+ other = Session ([('b' , b ), ( 'b12' , b12 ), ( 'a' , a ), ( 'a01' , a01 ), ( ' e' , e ), ('f' , f )])
396
398
res = sess == other
397
- assert list (res .keys ()) == ['e' , 'g' , 'f' ]
398
- assert [arr .all () for arr in res .values ()] == [True , False , True ]
399
+ assert list (res .keys ()) == ['b' , 'b12' , 'a' , 'a01' , 'e' , 'g' , 'f' ]
400
+ assert [item .all () if isinstance (item , LArray ) else item
401
+ for item in res .values ()] == [True , True , True , True , True , False , True ]
399
402
400
403
e2 = e .copy ()
401
404
e2 .i [1 , 1 ] = 42
402
- other = Session ([('e' , e2 ), ('f' , f )])
405
+ other = Session ([('b' , b ), ( 'b12' , b12 ), ( 'a' , a ), ( 'a01' , a01 ), ( ' e' , e2 ), ('f' , f )])
403
406
res = sess == other
404
- assert [arr .all () for arr in res .values ()] == [False , False , True ]
407
+ assert [item .all () if isinstance (item , LArray ) else item
408
+ for item in res .values ()] == [True , True , True , True , False , False , True ]
405
409
406
410
407
411
def test_ne (session ):
408
- sess = session .filter (kind = LArray )
409
- expected = Session ([('e' , e ), ('f' , f ), ('g' , g )])
410
- assert ([(~ array ).all () for array in (sess != expected ).values ()])
412
+ sess = session .filter (kind = (Axis , Group , LArray ))
413
+ expected = Session ([('b' , b ), ('b12' , b12 ), ('a' , a ), ('a01' , a01 ),
414
+ ('e' , e ), ('g' , g ), ('f' , f )])
415
+ assert ([(~ item ).all () if isinstance (item , LArray ) else not item
416
+ for item in (sess != expected ).values ()])
411
417
412
- other = Session ([('e' , e ), ('f' , f )])
418
+ other = Session ([('b' , b ), ( 'b12' , b12 ), ( 'a' , a ), ( 'a01' , a01 ), ( ' e' , e ), ('f' , f )])
413
419
res = sess != other
414
- assert [(~ arr ).all () for arr in res .values ()] == [True , False , True ]
420
+ assert list (res .keys ()) == ['b' , 'b12' , 'a' , 'a01' , 'e' , 'g' , 'f' ]
421
+ assert [(~ item ).all () if isinstance (item , LArray ) else not item
422
+ for item in res .values ()] == [True , True , True , True , True , False , True ]
415
423
416
424
e2 = e .copy ()
417
425
e2 .i [1 , 1 ] = 42
418
- other = Session ([('e' , e2 ), ('f' , f )])
426
+ other = Session ([('b' , b ), ( 'b12' , b12 ), ( 'a' , a ), ( 'a01' , a01 ), ( ' e' , e2 ), ('f' , f )])
419
427
res = sess != other
420
- assert [(~ arr ).all () for arr in res .values ()] == [False , False , True ]
428
+ assert [(~ item ).all () if isinstance (item , LArray ) else not item
429
+ for item in res .values ()] == [True , True , True , True , False , False , True ]
421
430
422
431
423
432
def test_sub (session ):
424
- sess = session . filter ( kind = LArray )
433
+ sess = session
425
434
426
435
# session - session
427
436
other = Session ({'e' : e - 1 , 'f' : ones_like (f )})
428
437
diff = sess - other
429
438
assert_array_nan_equal (diff ['e' ], np .full ((2 , 3 ), 1 , dtype = np .int32 ))
430
439
assert_array_nan_equal (diff ['f' ], f - ones_like (f ))
431
440
assert isnan (diff ['g' ]).all ()
441
+ assert diff .a is a
442
+ assert diff .a01 is a01
443
+ assert diff .c is c
432
444
433
445
# session - scalar
434
446
diff = sess - 2
435
447
assert_array_nan_equal (diff ['e' ], e - 2 )
436
448
assert_array_nan_equal (diff ['f' ], f - 2 )
437
449
assert_array_nan_equal (diff ['g' ], g - 2 )
450
+ assert diff .a is a
451
+ assert diff .a01 is a01
452
+ assert diff .c is c
438
453
439
454
# session - dict(LArray and scalar)
440
455
other = {'e' : ones_like (e ), 'f' : 1 }
441
456
diff = sess - other
442
457
assert_array_nan_equal (diff ['e' ], e - ones_like (e ))
443
458
assert_array_nan_equal (diff ['f' ], f - 1 )
444
459
assert isnan (diff ['g' ]).all ()
460
+ assert diff .a is a
461
+ assert diff .a01 is a01
462
+ assert diff .c is c
445
463
446
464
447
465
def test_rsub (session ):
448
- sess = session . filter ( kind = LArray )
466
+ sess = session
449
467
450
468
# scalar - session
451
469
diff = 2 - sess
452
470
assert_array_nan_equal (diff ['e' ], 2 - e )
453
471
assert_array_nan_equal (diff ['f' ], 2 - f )
454
472
assert_array_nan_equal (diff ['g' ], 2 - g )
473
+ assert diff .a is a
474
+ assert diff .a01 is a01
475
+ assert diff .c is c
455
476
456
477
# dict(LArray and scalar) - session
457
478
other = {'e' : ones_like (e ), 'f' : 1 }
458
479
diff = other - sess
459
480
assert_array_nan_equal (diff ['e' ], ones_like (e ) - e )
460
481
assert_array_nan_equal (diff ['f' ], 1 - f )
461
482
assert isnan (diff ['g' ]).all ()
483
+ assert diff .a is a
484
+ assert diff .a01 is a01
485
+ assert diff .c is c
462
486
463
487
464
488
def test_div (session ):
465
- sess = session . filter ( kind = LArray )
489
+ sess = session
466
490
other = Session ({'e' : e - 1 , 'f' : f + 1 })
467
491
468
492
with pytest .warns (RuntimeWarning ) as caught_warnings :
@@ -481,19 +505,25 @@ def test_div(session):
481
505
482
506
483
507
def test_rdiv (session ):
484
- sess = session . filter ( kind = LArray )
508
+ sess = session
485
509
486
510
# scalar / session
487
511
res = 2 / sess
488
512
assert_array_nan_equal (res ['e' ], 2 / e )
489
513
assert_array_nan_equal (res ['f' ], 2 / f )
490
514
assert_array_nan_equal (res ['g' ], 2 / g )
515
+ assert res .a is a
516
+ assert res .a01 is a01
517
+ assert res .c is c
491
518
492
519
# dict(LArray and scalar) - session
493
520
other = {'e' : e , 'f' : f }
494
521
res = other / sess
495
522
assert_array_nan_equal (res ['e' ], e / e )
496
523
assert_array_nan_equal (res ['f' ], f / f )
524
+ assert res .a is a
525
+ assert res .a01 is a01
526
+ assert res .c is c
497
527
498
528
499
529
def test_pickle_roundtrip (session , meta ):
0 commit comments