Skip to content

Commit ac91fef

Browse files
authored
Fix compatibility for binary operations (#142)
* Fix: Binary operations on two Grid objects with different edges parameters now correctly raise a TypeError as documented * Enhancement: any broadcastable object can now be used for binary operations on grids * Updated AUTHORS * Updated CHANGELOG (bump to 1.1.0)
1 parent 8b55865 commit ac91fef

File tree

4 files changed

+39
-11
lines changed

4 files changed

+39
-11
lines changed

AUTHORS

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,5 @@ Contributors:
2222
* Lily Wang <lilyminium>
2323
* Josh Vermaas <jvermaas>
2424
* Irfan Alibay <IAlibay>
25-
* Zhiyi Wu <xiki-tempula>
25+
* Zhiyi Wu <xiki-tempula>
26+
* Olivier Languin-Cattoën <ollyfutur>

CHANGELOG

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,14 +13,26 @@ The rules for this file:
1313
* accompany each entry with github issue/PR number (Issue #xyz)
1414

1515
------------------------------------------------------------------------------
16-
??/??/???? IAlibay
17-
* 1.0.3
16+
??/??/???? IAlibay, ollyfutur
17+
* 1.1.0
1818

1919
Changes
2020

2121
* Python 3.13 and 3.14 are now supported (PR #140)
2222
* Support for Python 3.9 and 3.10 is now dropped as per SPEC0 (PR #140)
2323

24+
Enhancements
25+
26+
* `Grid` now accepts binary operations with any operand that can be
27+
broadcasted to the grid's shape according to `numpy` broadcasting rules
28+
(PR #142)
29+
30+
Fixes
31+
32+
* Attempting binary operations on grids with different edges now raises an
33+
exception (PR #142)
34+
35+
2436
10/21/2023 IAlibay, orbeckst, lilyminium
2537

2638
* 1.0.2

gridData/core.py

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -711,7 +711,7 @@ def check_compatible(self, other):
711711
712712
`other` is compatible if
713713
714-
1) `other` is a scalar
714+
1) `other` is a scalar or an array-like broadcastable to the grid
715715
2) `other` is a grid defined on the same edges
716716
717717
In order to make `other` compatible, resample it on the same
@@ -732,13 +732,22 @@ def check_compatible(self, other):
732732
--------
733733
:meth:`resample`
734734
"""
735-
736-
if not (numpy.isreal(other) or self == other):
735+
if isinstance(other, Grid):
736+
is_compatible = all(
737+
numpy.allclose(other_edge, self_edge)
738+
for other_edge, self_edge in zip(other.edges, self.edges)
739+
)
740+
else:
741+
try:
742+
is_compatible = numpy.broadcast(self.grid, other).shape == self.grid.shape
743+
except ValueError:
744+
is_compatible = False
745+
if not is_compatible:
737746
raise TypeError(
738747
"The argument cannot be arithmetically combined with the grid. "
739-
"It must be a scalar or a grid with identical edges. "
740-
"Use Grid.resample(other.edges) to make a new grid that is "
741-
"compatible with other.")
748+
"It must be broadcastable to the grid's shape or a `Grid` with identical edges. "
749+
"Use `Grid.resample(other.edges)` to make a new grid that is "
750+
"compatible with `other`.")
742751
return True
743752

744753
def _interpolationFunctionFactory(self, spline_order=None, cval=None):

gridData/tests/test_grid.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -107,12 +107,18 @@ def test_power(self, data):
107107
def test_compatibility_type(self, data):
108108
assert data['grid'].check_compatible(data['grid'])
109109
assert data['grid'].check_compatible(3)
110-
g = Grid(data['griddata'], origin=data['origin'] - 1, delta=data['delta'])
110+
g = Grid(data['griddata'], origin=data['origin'], delta=data['delta'])
111111
assert data['grid'].check_compatible(g)
112+
assert data['grid'].check_compatible(g.grid)
112113

113114
def test_wrong_compatibile_type(self, data):
115+
g = Grid(data['griddata'], origin=data['origin'] + 1, delta=data['delta'])
116+
with pytest.raises(TypeError):
117+
data['grid'].check_compatible(g)
118+
119+
arr = np.zeros(data['griddata'].shape[-1] + 1) # Not broadcastable
114120
with pytest.raises(TypeError):
115-
data['grid'].check_compatible("foo")
121+
data['grid'].check_compatible(arr)
116122

117123
def test_non_orthonormal_boxes(self, data):
118124
delta = np.eye(3)

0 commit comments

Comments
 (0)