Skip to content

Commit e138e55

Browse files
committed
Fixed equality for types [skip ci]
1 parent c10799c commit e138e55

File tree

9 files changed

+41
-0
lines changed

9 files changed

+41
-0
lines changed

pgvector/django/functions.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ def __init__(self, expression, vector, **extra):
1313
vector = Value(SparseVector._to_db(vector))
1414
else:
1515
vector = Value(Vector._to_db(vector))
16+
17+
# prevent error with unhashable types
18+
self._constructor_args = ((expression, vector), extra)
19+
1620
super().__init__(expression, vector, **extra)
1721

1822

pgvector/utils/bit.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@ def __init__(self, value):
2424
def __repr__(self):
2525
return f'Bit({self.to_text()})'
2626

27+
def __eq__(self, other):
28+
if isinstance(other, self.__class__):
29+
return np.array_equal(self.to_numpy(), other.to_numpy())
30+
return False
31+
2732
def to_list(self):
2833
return self._value.tolist()
2934

pgvector/utils/halfvec.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ def __init__(self, value):
1616
def __repr__(self):
1717
return f'HalfVector({self.to_list()})'
1818

19+
def __eq__(self, other):
20+
if isinstance(other, self.__class__):
21+
return np.array_equal(self.to_numpy(), other.to_numpy())
22+
return False
23+
1924
def dimensions(self):
2025
return len(self._value)
2126

pgvector/utils/sparsevec.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,11 @@ def __repr__(self):
2626
elements = dict(zip(self._indices, self._values))
2727
return f'SparseVector({elements}, {self._dim})'
2828

29+
def __eq__(self, other):
30+
if isinstance(other, self.__class__):
31+
return self.dimensions() == other.dimensions() and self.indices() == other.indices() and self.values() == other.values()
32+
return False
33+
2934
def dimensions(self):
3035
return self._dim
3136

pgvector/utils/vector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,11 @@ def __init__(self, value):
1616
def __repr__(self):
1717
return f'Vector({self.to_list()})'
1818

19+
def __eq__(self, other):
20+
if isinstance(other, self.__class__):
21+
return np.array_equal(self.to_numpy(), other.to_numpy())
22+
return False
23+
1924
def dimensions(self):
2025
return len(self._value)
2126

tests/test_bit.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,7 @@ def test_ndim_zero(self):
3737
def test_repr(self):
3838
assert repr(Bit([True, False, True])) == 'Bit(101)'
3939
assert str(Bit([True, False, True])) == 'Bit(101)'
40+
41+
def test_equality(self):
42+
assert Bit([True, False, True]) == Bit([True, False, True])
43+
assert Bit([True, False, True]) != Bit([True, False, False])

tests/test_half_vector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,9 @@ def test_repr(self):
3838
assert repr(HalfVector([1, 2, 3])) == 'HalfVector([1.0, 2.0, 3.0])'
3939
assert str(HalfVector([1, 2, 3])) == 'HalfVector([1.0, 2.0, 3.0])'
4040

41+
def test_equality(self):
42+
assert HalfVector([1, 2, 3]) == HalfVector([1, 2, 3])
43+
assert HalfVector([1, 2, 3]) != HalfVector([1, 2, 4])
44+
4145
def test_dimensions(self):
4246
assert HalfVector([1, 2, 3]).dimensions() == 3

tests/test_sparse_vector.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,11 @@ def test_repr(self):
5252
assert repr(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)'
5353
assert str(SparseVector([1, 0, 2, 0, 3, 0])) == 'SparseVector({0: 1.0, 2: 2.0, 4: 3.0}, 6)'
5454

55+
def test_equality(self):
56+
assert SparseVector([1, 0, 2, 0, 3, 0]) == SparseVector([1, 0, 2, 0, 3, 0])
57+
assert SparseVector([1, 0, 2, 0, 3, 0]) != SparseVector([1, 0, 2, 0, 3, 1])
58+
assert SparseVector([1, 0, 2, 0, 3, 0]) == SparseVector({2: 2, 4: 3, 0: 1, 3: 0}, 6)
59+
5560
def test_dimensions(self):
5661
assert SparseVector([1, 0, 2, 0, 3, 0]).dimensions() == 6
5762

tests/test_vector.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,5 +38,9 @@ def test_repr(self):
3838
assert repr(Vector([1, 2, 3])) == 'Vector([1.0, 2.0, 3.0])'
3939
assert str(Vector([1, 2, 3])) == 'Vector([1.0, 2.0, 3.0])'
4040

41+
def test_equality(self):
42+
assert Vector([1, 2, 3]) == Vector([1, 2, 3])
43+
assert Vector([1, 2, 3]) != Vector([1, 2, 4])
44+
4145
def test_dimensions(self):
4246
assert Vector([1, 2, 3]).dimensions() == 3

0 commit comments

Comments
 (0)