Skip to content

Commit af28f77

Browse files
authored
__eq__ method for Corr class (#206)
* feat: implemented __eq__ method for Corr class. * feat: __eq__ method now respects None entries in correlators. * feat: Obs can now be compared to None, __ne__ method removed as it is not required. * feat: Corr.__eq__ rewritten to give a per element comparison. * tests: additional test case for correlator comparison added. * feat: comparison now also works for padding.
1 parent 1e43835 commit af28f77

File tree

4 files changed

+46
-3
lines changed

4 files changed

+46
-3
lines changed

pyerrors/correlators.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1056,6 +1056,13 @@ def __str__(self):
10561056

10571057
__array_priority__ = 10000
10581058

1059+
def __eq__(self, y):
1060+
if isinstance(y, Corr):
1061+
comp = np.asarray(y.content, dtype=object)
1062+
else:
1063+
comp = np.asarray(y)
1064+
return np.asarray(self.content, dtype=object) == comp
1065+
10591066
def __add__(self, y):
10601067
if isinstance(y, Corr):
10611068
if ((self.N != y.N) or (self.T != y.T)):

pyerrors/obs.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -773,11 +773,10 @@ def __ge__(self, other):
773773
return self.value >= other
774774

775775
def __eq__(self, other):
776+
if other is None:
777+
return False
776778
return (self - other).is_zero()
777779

778-
def __ne__(self, other):
779-
return not (self - other).is_zero()
780-
781780
# Overload math operations
782781
def __add__(self, y):
783782
if isinstance(y, Obs):

tests/correlators_test.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -713,3 +713,39 @@ def test_corr_roll():
713713
tt = mcorr.roll(T) - mcorr
714714
for el in tt:
715715
assert np.all(el == 0)
716+
717+
718+
def test_correlator_comparison():
719+
scorr = pe.Corr([pe.pseudo_Obs(0.3, 0.1, "test") for o in range(4)])
720+
mcorr = pe.Corr(np.array([[scorr, scorr], [scorr, scorr]]))
721+
for corr in [scorr, mcorr]:
722+
assert (corr == corr).all()
723+
assert np.all(corr == 1 * corr)
724+
assert np.all(corr == (1 + 1e-16) * corr)
725+
assert not np.all(corr == (1 + 1e-5) * corr)
726+
assert np.all(corr == 1 / (1 / corr))
727+
assert np.all(corr - corr == 0)
728+
assert np.all(corr * 0 == 0)
729+
assert np.all(0 * corr == 0)
730+
assert np.all(0 * corr + scorr[2] == scorr[2])
731+
assert np.all(-corr == 0 - corr)
732+
assert np.all(corr ** 2 == corr * corr)
733+
acorr = pe.Corr([scorr[0]] * 6)
734+
assert np.all(acorr == scorr[0])
735+
assert not np.all(acorr == scorr[1])
736+
737+
mcorr[1][0, 1] = None
738+
assert not np.all(mcorr == pe.Corr(np.array([[scorr, scorr], [scorr, scorr]])))
739+
740+
pcorr = pe.Corr([pe.pseudo_Obs(0.25, 0.1, "test") for o in range(2)], padding=[1, 1])
741+
assert np.all(pcorr == pcorr)
742+
assert np.all(1 * pcorr == pcorr)
743+
744+
745+
def test_corr_item():
746+
corr_aa = _gen_corr(1)
747+
corr_ab = 0.5 * corr_aa
748+
749+
corr_mat = pe.Corr(np.array([[corr_aa, corr_ab], [corr_ab, corr_aa]]))
750+
corr_mat.item(0, 0)
751+
assert corr_mat[0].item(0, 1) == corr_mat.item(0, 1)[0]

tests/obs_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ def test_comparison():
103103
test_obs1 = pe.pseudo_Obs(value1, 0.1, 't')
104104
value2 = np.random.normal(0, 100)
105105
test_obs2 = pe.pseudo_Obs(value2, 0.1, 't')
106+
assert test_obs1 != None
106107
assert (value1 > value2) == (test_obs1 > test_obs2)
107108
assert (value1 < value2) == (test_obs1 < test_obs2)
108109
assert (value1 >= value2) == (test_obs1 >= test_obs2)

0 commit comments

Comments
 (0)