|
1 | 1 | import django |
2 | 2 | from django.conf import settings |
3 | 3 | from django.contrib.postgres.fields import ArrayField |
| 4 | +from django.contrib.postgres.indexes import OpClass |
4 | 5 | from django.core import serializers |
5 | 6 | from django.db import connection, migrations, models |
6 | 7 | from django.db.models import Avg, Sum, FloatField, DecimalField |
|
38 | 39 | 'level': 'WARNING' |
39 | 40 | } |
40 | 41 | } |
41 | | - } |
| 42 | + }, |
| 43 | + # needed for OpClass |
| 44 | + # https://docs.djangoproject.com/en/5.1/ref/contrib/postgres/indexes/#opclass-expressions |
| 45 | + INSTALLED_APPS=[ |
| 46 | + 'django.contrib.postgres' |
| 47 | + ] |
42 | 48 | ) |
43 | 49 | django.setup() |
44 | 50 |
|
@@ -67,6 +73,12 @@ class Meta: |
67 | 73 | m=16, |
68 | 74 | ef_construction=64, |
69 | 75 | opclasses=['vector_l2_ops'] |
| 76 | + ), |
| 77 | + HnswIndex( |
| 78 | + OpClass(Cast('embedding', HalfVectorField(dimensions=3)), name='halfvec_l2_ops'), |
| 79 | + name='hnsw_half_precision_idx', |
| 80 | + m=16, |
| 81 | + ef_construction=64 |
70 | 82 | ) |
71 | 83 | ] |
72 | 84 |
|
@@ -99,6 +111,10 @@ class Migration(migrations.Migration): |
99 | 111 | migrations.AddIndex( |
100 | 112 | model_name='item', |
101 | 113 | index=pgvector.django.HnswIndex(fields=['embedding'], m=16, ef_construction=64, name='hnsw_idx', opclasses=['vector_l2_ops']), |
| 114 | + ), |
| 115 | + migrations.AddIndex( |
| 116 | + model_name='item', |
| 117 | + index=pgvector.django.HnswIndex(OpClass(Cast('embedding', HalfVectorField(dimensions=3)), name='halfvec_l2_ops'), m=16, ef_construction=64, name='hnsw_half_precision_idx'), |
102 | 118 | ) |
103 | 119 | ] |
104 | 120 |
|
@@ -473,3 +489,10 @@ def test_numeric_array(self): |
473 | 489 | assert [v.id for v in items] == [1, 3, 2] |
474 | 490 | assert [v.distance for v in items] == [0, 1, sqrt(3)] |
475 | 491 | assert items[1].numeric_embedding == [1, 1, 2] |
| 492 | + |
| 493 | + def test_half_precision(self): |
| 494 | + create_items() |
| 495 | + distance = L2Distance(Cast('embedding', HalfVectorField(dimensions=3)), [1, 1, 1]) |
| 496 | + items = Item.objects.annotate(distance=distance).order_by(distance) |
| 497 | + assert [v.id for v in items] == [1, 3, 2] |
| 498 | + assert [v.distance for v in items] == [0, 1, sqrt(3)] |
0 commit comments