Skip to content

Commit 0df64c0

Browse files
committed
Support equality on multidimensional arrays
1 parent 2bb8d24 commit 0df64c0

File tree

2 files changed

+78
-3
lines changed

2 files changed

+78
-3
lines changed

Src/IronPython/Runtime/Operations/ArrayOps.cs

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -93,10 +93,38 @@ public static object __eq__(CodeContext context, Array self, [NotNone] Array oth
9393
if (other is null) throw PythonOps.TypeError("expected Array, got None");
9494

9595
if (self.GetType() != other.GetType()) return ScriptingRuntimeHelpers.False;
96+
// same type implies: same rank, same element type
97+
for (int d = 0; d < self.Rank; d++) {
98+
if (self.GetLowerBound(d) != other.GetLowerBound(d)) return ScriptingRuntimeHelpers.False;
99+
if (self.GetUpperBound(d) != other.GetUpperBound(d)) return ScriptingRuntimeHelpers.False;
100+
}
101+
if (self.Length == 0) return ScriptingRuntimeHelpers.True; // fast track
96102

97-
return ScriptingRuntimeHelpers.BooleanToObject(
98-
((IStructuralEquatable)self).Equals(other, context.LanguageContext.EqualityComparerNonGeneric)
99-
);
103+
if (self.Rank == 1 && self.GetLowerBound(0) == 0 ) {
104+
// IStructuralEquatable.Equals only works for 1-dim, 0-based arrays
105+
return ScriptingRuntimeHelpers.BooleanToObject(
106+
((IStructuralEquatable)self).Equals(other, context.LanguageContext.EqualityComparerNonGeneric)
107+
);
108+
} else {
109+
int[] ix = new int[self.Rank];
110+
for (int d = 0; d < self.Rank; d++) {
111+
ix[d] = self.GetLowerBound(d);
112+
}
113+
for (int i = 0; i < self.Length; i++) {
114+
if (!PythonOps.EqualRetBool(self.GetValue(ix), other.GetValue(ix))) {
115+
return ScriptingRuntimeHelpers.False;
116+
}
117+
for (int d = self.Rank - 1; d >= 0; d--) {
118+
if (ix[d] < self.GetUpperBound(d)) {
119+
ix[d]++;
120+
break;
121+
} else {
122+
ix[d] = self.GetLowerBound(d);
123+
}
124+
}
125+
}
126+
return ScriptingRuntimeHelpers.True;
127+
}
100128
}
101129

102130
[StaticExtensionMethod]

Tests/test_array.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -318,4 +318,51 @@ def test_equality(self):
318318
self.assertTrue(a != l)
319319
self.assertTrue(l != a)
320320

321+
def test_equality_base(self):
322+
a = System.Array.CreateInstance(int, (5,), (5,))
323+
a2 = System.Array.CreateInstance(int, (5,), (5,))
324+
b = System.Array.CreateInstance(int, (6,), (5,))
325+
c = System.Array.CreateInstance(int, (5,), (6,))
326+
d = System.Array.CreateInstance(int, (6,), (6,))
327+
328+
self.assertTrue(a == a2)
329+
self.assertFalse(a == b)
330+
self.assertFalse(a == c)
331+
self.assertFalse(a == d)
332+
333+
self.assertFalse(a != a2)
334+
self.assertTrue(a != b)
335+
self.assertTrue(a != c)
336+
self.assertTrue(a != d)
337+
338+
def test_equality_rank(self):
339+
a = System.Array.CreateInstance(int, 5, 6)
340+
a2 = System.Array.CreateInstance(int, 5, 6)
341+
b = System.Array.CreateInstance(int, 5, 6)
342+
b[0, 0] = 1
343+
c = System.Array.CreateInstance(int, (6, 5), (0, 0))
344+
c[0, 0] = 1
345+
d = System.Array.CreateInstance(int, (6, 5), (1, 1))
346+
d[1, 1] = 1
347+
d1 = System.Array.CreateInstance(int, (6, 5), (1, 1))
348+
d1[1, 1] = 1
349+
350+
self.assertTrue(a == a2)
351+
self.assertFalse(a == b) # different element
352+
self.assertFalse(a == c) # different rank
353+
self.assertFalse(a == d) # different rank
354+
self.assertFalse(b == c) # different shape
355+
self.assertFalse(b == d) # different shape & base
356+
self.assertFalse(c == d) # different base
357+
self.assertTrue(d == d1)
358+
359+
self.assertFalse(a != a2)
360+
self.assertTrue(a != b)
361+
self.assertTrue(a != c)
362+
self.assertTrue(a != d)
363+
self.assertTrue(b != c)
364+
self.assertTrue(b != d)
365+
self.assertTrue(c != d)
366+
self.assertFalse(d != d1)
367+
321368
run_test(__name__)

0 commit comments

Comments
 (0)