Skip to content

Commit 1eff988

Browse files
committed
fixed binary ops between LArray and scalar Groups which silently gave wrong results (closes #797)
1 parent 6d27e32 commit 1eff988

File tree

3 files changed

+18
-0
lines changed

3 files changed

+18
-0
lines changed

doc/source/changes/version_0_31.rst.inc

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,9 @@ Miscellaneous improvements
3333
Fixes
3434
^^^^^
3535

36+
* fixed binary operations (+, -, *, etc.) between an LArray and a (scalar) Group which silently gave a wrong
37+
result (closes :issue:`797`).
38+
3639
* fixed taking a subset of an array with boolean labels for an axis if the user explicitly specify the axis
3740
(closes :issue:`735`). When the user does not specify the axis, it currently fails but it is unclear what to do in
3841
that case (see :issue:`794`).

larray/core/array.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5476,6 +5476,13 @@ def opmethod(self, other):
54765476
if isinstance(other, ExprNode):
54775477
other = other.evaluate(self.axes)
54785478

5479+
# XXX: unsure what happens for non scalar Groups.
5480+
# we might want to be more general than this and .eval all Groups?
5481+
# or (and I think it's better) define __larray__ on Group
5482+
# so that a non scalar Group acts like an Axis in this situation.
5483+
if isinstance(other, Group) and np.isscalar(other.key):
5484+
other = other.eval()
5485+
54795486
# we could pass scalars through aslarray too but it is too costly performance-wise for only suppressing one
54805487
# isscalar test and an if statement.
54815488
# TODO: ndarray should probably be converted to larrays because that would harmonize broadcasting rules, but

larray/tests/test_array.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2733,6 +2733,14 @@ def test_broadcasting_no_name():
27332733
np.asarray(a) * np.asarray(c)
27342734

27352735

2736+
def test_binary_ops_with_scalar_group():
2737+
time = Axis('time=2015..2019')
2738+
arr = ndtest(3)
2739+
expected = arr + 2015
2740+
assert_larray_equal(time.i[0] + arr, expected)
2741+
assert_larray_equal(arr + time.i[0], expected)
2742+
2743+
27362744
def test_unary_ops(small_array):
27372745
raw = small_array.data
27382746

0 commit comments

Comments
 (0)