Skip to content

Commit 9e1c421

Browse files
committed
Added docs and test for half-precision indexing with Django
1 parent 49072f2 commit 9e1c421

File tree

2 files changed

+49
-1
lines changed

2 files changed

+49
-1
lines changed

README.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,31 @@ class Item(models.Model):
133133

134134
Use `vector_ip_ops` for inner product and `vector_cosine_ops` for cosine distance
135135

136+
#### Half-Precision Indexing
137+
138+
Index vectors at half-precision
139+
140+
```python
141+
from django.contrib.postgres.indexes import OpClass
142+
from django.db.models.functions import Cast
143+
from pgvector.django import HalfVectorField
144+
145+
index = HnswIndex(
146+
OpClass(Cast('embedding', HalfVectorField(dimensions=3)), name='halfvec_l2_ops'),
147+
name='my_index',
148+
m=16,
149+
ef_construction=64
150+
)
151+
```
152+
153+
Note: Add `'django.contrib.postgres'` to `INSTALLED_APPS` to use `OpClass`
154+
155+
Get the nearest neighbors
156+
157+
```python
158+
Item.objects.order_by(L2Distance(Cast('embedding', HalfVectorField(dimensions=3)), [3, 1, 2]))[:5]
159+
```
160+
136161
## SQLAlchemy
137162

138163
Enable the extension

tests/test_django.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import django
22
from django.conf import settings
33
from django.contrib.postgres.fields import ArrayField
4+
from django.contrib.postgres.indexes import OpClass
45
from django.core import serializers
56
from django.db import connection, migrations, models
67
from django.db.models import Avg, Sum, FloatField, DecimalField
@@ -38,7 +39,12 @@
3839
'level': 'WARNING'
3940
}
4041
}
41-
}
42+
},
43+
# needed for OpClass
44+
# https://docs.djangoproject.com/en/5.1/ref/contrib/postgres/indexes/#opclass-expressions
45+
INSTALLED_APPS=[
46+
'django.contrib.postgres'
47+
]
4248
)
4349
django.setup()
4450

@@ -67,6 +73,12 @@ class Meta:
6773
m=16,
6874
ef_construction=64,
6975
opclasses=['vector_l2_ops']
76+
),
77+
HnswIndex(
78+
OpClass(Cast('embedding', HalfVectorField(dimensions=3)), name='halfvec_l2_ops'),
79+
name='hnsw_half_precision_idx',
80+
m=16,
81+
ef_construction=64
7082
)
7183
]
7284

@@ -99,6 +111,10 @@ class Migration(migrations.Migration):
99111
migrations.AddIndex(
100112
model_name='item',
101113
index=pgvector.django.HnswIndex(fields=['embedding'], m=16, ef_construction=64, name='hnsw_idx', opclasses=['vector_l2_ops']),
114+
),
115+
migrations.AddIndex(
116+
model_name='item',
117+
index=pgvector.django.HnswIndex(OpClass(Cast('embedding', HalfVectorField(dimensions=3)), name='halfvec_l2_ops'), m=16, ef_construction=64, name='hnsw_half_precision_idx'),
102118
)
103119
]
104120

@@ -473,3 +489,10 @@ def test_numeric_array(self):
473489
assert [v.id for v in items] == [1, 3, 2]
474490
assert [v.distance for v in items] == [0, 1, sqrt(3)]
475491
assert items[1].numeric_embedding == [1, 1, 2]
492+
493+
def test_half_precision(self):
494+
create_items()
495+
distance = L2Distance(Cast('embedding', HalfVectorField(dimensions=3)), [1, 1, 1])
496+
items = Item.objects.annotate(distance=distance).order_by(distance)
497+
assert [v.id for v in items] == [1, 3, 2]
498+
assert [v.distance for v in items] == [0, 1, sqrt(3)]

0 commit comments

Comments
 (0)