Skip to content

Commit 2929350

Browse files
author
Diptorup Deb
committed
Add NdRange.__eq__ method to make object equality checks possible.
1 parent 601789a commit 2929350

File tree

3 files changed

+10
-24
lines changed

3 files changed

+10
-24
lines changed

numba_dpex/core/kernel_interface/indexers.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,12 @@ def __repr__(self):
219219
"""
220220
return self.__str__()
221221

222+
def __eq__(self, other):
223+
return (
224+
self.global_range == other.global_range
225+
and self.local_range == other.local_range
226+
)
227+
222228

223229
@intrinsic
224230
def _intrin_range_alloc(typingctx, ty_dim0, ty_dim1, ty_dim2, ty_range):

numba_dpex/tests/core/types/range_types/test_constructor_overloads.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ def _tester(r):
1818
r_expected = Range(*r)
1919
r_out = _tester(r)
2020

21-
assert r_out.ndim == r_expected.ndim
22-
assert r_out.dim0 == r_expected.dim0
23-
assert r_out.dim1 == r_expected.dim1
24-
assert r_out.dim2 == r_expected.dim2
21+
assert r_out == r_expected
2522

2623

2724
@pytest.mark.parametrize("r", ranges)
@@ -35,11 +32,4 @@ def _tester(r):
3532
r_expected = NdRange(gr, lr)
3633
r_out = _tester(r)
3734

38-
assert r_out.global_range.ndim == r_expected.global_range.ndim
39-
assert r_out.local_range.ndim == r_expected.local_range.ndim
40-
assert r_out.global_range.dim0 == r_expected.global_range.dim0
41-
assert r_out.global_range.dim1 == r_expected.global_range.dim1
42-
assert r_out.global_range.dim2 == r_expected.global_range.dim2
43-
assert r_out.local_range.dim0 == r_expected.local_range.dim0
44-
assert r_out.local_range.dim1 == r_expected.local_range.dim1
45-
assert r_out.local_range.dim2 == r_expected.local_range.dim2
35+
assert r_out == r_expected

numba_dpex/tests/core/types/range_types/test_unbox_box.py

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,7 @@ def _tester(r):
1818
r_in = Range(*r)
1919
r_out = _tester(r_in)
2020

21-
assert r_out.ndim == r_in.ndim
22-
assert r_out.dim0 == r_in.dim0
23-
assert r_out.dim1 == r_in.dim1
24-
assert r_out.dim2 == r_in.dim2
21+
assert r_out == r_in
2522

2623

2724
@pytest.mark.parametrize("r", ranges)
@@ -34,11 +31,4 @@ def _tester(r):
3431
r_in = NdRange(gr, lr)
3532
r_out = _tester(r_in)
3633

37-
assert r_out.global_range.ndim == r_in.global_range.ndim
38-
assert r_out.local_range.ndim == r_in.local_range.ndim
39-
assert r_out.global_range.dim0 == r_in.global_range.dim0
40-
assert r_out.global_range.dim1 == r_in.global_range.dim1
41-
assert r_out.global_range.dim2 == r_in.global_range.dim2
42-
assert r_out.local_range.dim0 == r_in.local_range.dim0
43-
assert r_out.local_range.dim1 == r_in.local_range.dim1
44-
assert r_out.local_range.dim2 == r_in.local_range.dim2
34+
assert r_out == r_in

0 commit comments

Comments
 (0)