Skip to content

Commit 6770e84

Browse files
committed
fix #725 : binary operators +, -, *, / on sessions does not include non-LArray objects
1 parent e61a076 commit 6770e84

File tree

3 files changed

+70
-35
lines changed

3 files changed

+70
-35
lines changed

doc/source/changes/version_0_30.rst.inc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,4 +226,6 @@ Fixes
226226
to an Excel Workbook (closes :issue:`713`).
227227

228228
* fixed missing documentation of many functions in :ref:`Utility Functions <api-ufuncs>` section
229-
of the API Reference (closes :issue:`698`).
229+
of the API Reference (closes :issue:`698`).
230+
231+
* fixed arithmetic operations between two sessions returning a nan value for each axis and group (closes :issue:`725`).

larray/core/session.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -820,7 +820,7 @@ def __len__(self):
820820
return len(self._objects)
821821

822822
# binary operations are dispatched element-wise to all arrays (we consider Session as an array-like)
823-
def _binop(opname):
823+
def _binop(opname, arrays_only=True):
824824
opfullname = '__%s__' % opname
825825

826826
def opmethod(self, other):
@@ -831,21 +831,24 @@ def opmethod(self, other):
831831
with np.errstate(call=_session_float_error_handler):
832832
res = []
833833
for name in all_keys:
834-
self_array = self.get(name, nan)
834+
self_item = self.get(name, nan)
835835
other_operand = other.get(name, nan) if hasattr(other, 'get') else other
836-
try:
837-
res_array = getattr(self_array, opfullname)(other_operand)
838-
# TypeError for str arrays, ValueError for incompatible axes, ...
839-
except Exception:
840-
res_array = nan
841-
# this should only ever happen when self_array is a non Array (eg. nan)
842-
if res_array is NotImplemented:
836+
if arrays_only and not isinstance(self_item, LArray):
837+
res_item = self_item
838+
else:
843839
try:
844-
res_array = getattr(other_operand, '__%s__' % inverseop(opname))(self_array)
840+
res_item = getattr(self_item, opfullname)(other_operand)
845841
# TypeError for str arrays, ValueError for incompatible axes, ...
846842
except Exception:
847-
res_array = nan
848-
res.append((name, res_array))
843+
res_item = nan
844+
# this should only ever happen when self_array is a non Array (eg. nan)
845+
if res_item is NotImplemented:
846+
try:
847+
res_item = getattr(other_operand, '__%s__' % inverseop(opname))(self_item)
848+
# TypeError for str arrays, ValueError for incompatible axes, ...
849+
except Exception:
850+
res_item = nan
851+
res.append((name, res_item))
849852
return Session(res)
850853
opmethod.__name__ = opfullname
851854
return opmethod
@@ -859,8 +862,8 @@ def opmethod(self, other):
859862
__truediv__ = _binop('truediv')
860863
__rtruediv__ = _binop('rtruediv')
861864

862-
__eq__ = _binop('eq')
863-
__ne__ = _binop('ne')
865+
__eq__ = _binop('eq', arrays_only=False)
866+
__ne__ = _binop('ne', arrays_only=False)
864867

865868
# element-wise method factory
866869
# unary operations are (also) dispatched element-wise to all arrays

larray/tests/test_session.py

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from larray.tests.common import assert_array_nan_equal, inputpath, tmp_path, meta
1111
from larray import (Session, Axis, LArray, Group, isnan, zeros_like, ndtest, ones_like,
12-
local_arrays, global_arrays, arrays)
12+
local_arrays, global_arrays, arrays, nan)
1313
from larray.util.misc import pickle
1414

1515
try:
@@ -388,81 +388,105 @@ def test_element_equals(session):
388388

389389

390390
def test_eq(session):
391-
sess = session.filter(kind=LArray)
392-
expected = Session([('e', e), ('f', f), ('g', g)])
393-
assert all([array.all() for array in (sess == expected).values()])
391+
sess = session.filter(kind=(Axis, Group, LArray))
392+
expected = Session([('b', b), ('b12', b12), ('a', a), ('a01', a01),
393+
('e', e), ('g', g), ('f', f)])
394+
assert all([item.all() if isinstance(item, LArray) else item
395+
for item in (sess == expected).values()])
394396

395-
other = Session([('e', e), ('f', f)])
397+
other = Session([('b', b), ('b12', b12), ('a', a), ('a01', a01), ('e', e), ('f', f)])
396398
res = sess == other
397-
assert list(res.keys()) == ['e', 'g', 'f']
398-
assert [arr.all() for arr in res.values()] == [True, False, True]
399+
assert list(res.keys()) == ['b', 'b12', 'a', 'a01', 'e', 'g', 'f']
400+
assert [item.all() if isinstance(item, LArray) else item
401+
for item in res.values()] == [True, True, True, True, True, False, True]
399402

400403
e2 = e.copy()
401404
e2.i[1, 1] = 42
402-
other = Session([('e', e2), ('f', f)])
405+
other = Session([('b', b), ('b12', b12), ('a', a), ('a01', a01), ('e', e2), ('f', f)])
403406
res = sess == other
404-
assert [arr.all() for arr in res.values()] == [False, False, True]
407+
assert [item.all() if isinstance(item, LArray) else item
408+
for item in res.values()] == [True, True, True, True, False, False, True]
405409

406410

407411
def test_ne(session):
408-
sess = session.filter(kind=LArray)
409-
expected = Session([('e', e), ('f', f), ('g', g)])
410-
assert ([(~array).all() for array in (sess != expected).values()])
412+
sess = session.filter(kind=(Axis, Group, LArray))
413+
expected = Session([('b', b), ('b12', b12), ('a', a), ('a01', a01),
414+
('e', e), ('g', g), ('f', f)])
415+
assert ([(~item).all() if isinstance(item, LArray) else not item
416+
for item in (sess != expected).values()])
411417

412-
other = Session([('e', e), ('f', f)])
418+
other = Session([('b', b), ('b12', b12), ('a', a), ('a01', a01), ('e', e), ('f', f)])
413419
res = sess != other
414-
assert [(~arr).all() for arr in res.values()] == [True, False, True]
420+
assert list(res.keys()) == ['b', 'b12', 'a', 'a01', 'e', 'g', 'f']
421+
assert [(~item).all() if isinstance(item, LArray) else not item
422+
for item in res.values()] == [True, True, True, True, True, False, True]
415423

416424
e2 = e.copy()
417425
e2.i[1, 1] = 42
418-
other = Session([('e', e2), ('f', f)])
426+
other = Session([('b', b), ('b12', b12), ('a', a), ('a01', a01), ('e', e2), ('f', f)])
419427
res = sess != other
420-
assert [(~arr).all() for arr in res.values()] == [False, False, True]
428+
assert [(~item).all() if isinstance(item, LArray) else not item
429+
for item in res.values()] == [True, True, True, True, False, False, True]
421430

422431

423432
def test_sub(session):
424-
sess = session.filter(kind=LArray)
433+
sess = session
425434

426435
# session - session
427436
other = Session({'e': e - 1, 'f': ones_like(f)})
428437
diff = sess - other
429438
assert_array_nan_equal(diff['e'], np.full((2, 3), 1, dtype=np.int32))
430439
assert_array_nan_equal(diff['f'], f - ones_like(f))
431440
assert isnan(diff['g']).all()
441+
assert diff.a is a
442+
assert diff.a01 is a01
443+
assert diff.c is c
432444

433445
# session - scalar
434446
diff = sess - 2
435447
assert_array_nan_equal(diff['e'], e - 2)
436448
assert_array_nan_equal(diff['f'], f - 2)
437449
assert_array_nan_equal(diff['g'], g - 2)
450+
assert diff.a is a
451+
assert diff.a01 is a01
452+
assert diff.c is c
438453

439454
# session - dict(LArray and scalar)
440455
other = {'e': ones_like(e), 'f': 1}
441456
diff = sess - other
442457
assert_array_nan_equal(diff['e'], e - ones_like(e))
443458
assert_array_nan_equal(diff['f'], f - 1)
444459
assert isnan(diff['g']).all()
460+
assert diff.a is a
461+
assert diff.a01 is a01
462+
assert diff.c is c
445463

446464

447465
def test_rsub(session):
448-
sess = session.filter(kind=LArray)
466+
sess = session
449467

450468
# scalar - session
451469
diff = 2 - sess
452470
assert_array_nan_equal(diff['e'], 2 - e)
453471
assert_array_nan_equal(diff['f'], 2 - f)
454472
assert_array_nan_equal(diff['g'], 2 - g)
473+
assert diff.a is a
474+
assert diff.a01 is a01
475+
assert diff.c is c
455476

456477
# dict(LArray and scalar) - session
457478
other = {'e': ones_like(e), 'f': 1}
458479
diff = other - sess
459480
assert_array_nan_equal(diff['e'], ones_like(e) - e)
460481
assert_array_nan_equal(diff['f'], 1 - f)
461482
assert isnan(diff['g']).all()
483+
assert diff.a is a
484+
assert diff.a01 is a01
485+
assert diff.c is c
462486

463487

464488
def test_div(session):
465-
sess = session.filter(kind=LArray)
489+
sess = session
466490
other = Session({'e': e - 1, 'f': f + 1})
467491

468492
with pytest.warns(RuntimeWarning) as caught_warnings:
@@ -481,19 +505,25 @@ def test_div(session):
481505

482506

483507
def test_rdiv(session):
484-
sess = session.filter(kind=LArray)
508+
sess = session
485509

486510
# scalar / session
487511
res = 2 / sess
488512
assert_array_nan_equal(res['e'], 2 / e)
489513
assert_array_nan_equal(res['f'], 2 / f)
490514
assert_array_nan_equal(res['g'], 2 / g)
515+
assert res.a is a
516+
assert res.a01 is a01
517+
assert res.c is c
491518

492519
# dict(LArray and scalar) - session
493520
other = {'e': e, 'f': f}
494521
res = other / sess
495522
assert_array_nan_equal(res['e'], e / e)
496523
assert_array_nan_equal(res['f'], f / f)
524+
assert res.a is a
525+
assert res.a01 is a01
526+
assert res.c is c
497527

498528

499529
def test_pickle_roundtrip(session, meta):

0 commit comments

Comments
 (0)