Skip to content

Commit 226f22a

Browse files
wip
1 parent 42f1834 commit 226f22a

File tree

3 files changed

+22
-17
lines changed

3 files changed

+22
-17
lines changed

elasticsearch/dsl/utils.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -609,11 +609,11 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
609609
if isinstance(v, AttrList):
610610
v = v._l_
611611

612-
if skip_empty:
613-
# don't serialize empty values
614-
# careful not to include numeric zeros
615-
if v in ([], {}, None):
616-
continue
612+
# if skip_empty:
613+
# # don't serialize empty values
614+
# # careful not to include numeric zeros
615+
# if v in ([], {}, None):
616+
# continue
617617

618618
out[k] = v
619619
return out

examples/quotes/backend/quotes.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import base64
23
import csv
34
import os
45
from time import time
@@ -8,19 +9,23 @@
89
from pydantic import BaseModel, Field, ValidationError
910
from sentence_transformers import SentenceTransformer
1011

11-
from elasticsearch import NotFoundError
12+
from elasticsearch import NotFoundError, OrjsonSerializer
1213
from elasticsearch.dsl.pydantic import AsyncBaseESModel
1314
from elasticsearch import dsl
15+
from elasticsearch.dsl.types import DenseVectorIndexOptions
1416

1517
model = SentenceTransformer("all-MiniLM-L6-v2")
16-
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']])
18+
dsl.async_connections.create_connection(hosts=[os.environ['ELASTICSEARCH_URL']], serializer=OrjsonSerializer()
19+
)
1720

1821

1922
class Quote(AsyncBaseESModel):
2023
quote: str
2124
author: Annotated[str, dsl.Keyword()]
2225
tags: Annotated[list[str], dsl.Keyword()]
23-
embedding: Annotated[list[float], dsl.DenseVector()] = Field(init=False, default=[])
26+
embedding: Annotated[list[float], dsl.DenseVector(
27+
index_options=DenseVectorIndexOptions(type="flat"),
28+
)] = Field(init=False, default=[])
2429

2530
class Index:
2631
name = 'quotes'
@@ -135,7 +140,10 @@ async def search_quotes(req: SearchRequest) -> SearchResponse:
135140
def embed_quotes(quotes):
136141
embeddings = model.encode([q.quote for q in quotes])
137142
for q, e in zip(quotes, embeddings):
138-
q.embedding = e.tolist()
143+
q.embedding = e
144+
# q.embedding = e.tolist()
145+
##byte_array = e.byteswap().tobytes()
146+
##q.embedding = base64.b64encode(byte_array).decode()
139147

140148

141149
async def ingest_quotes():

utils/templates/field.py.tpl

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -159,7 +159,11 @@ class Field(DslBase):
159159
def clean(self, data: Any) -> Any:
160160
if data is not None:
161161
data = self.deserialize(data)
162-
if data in (None, [], {}) and self._required:
162+
# the "data is ..." comparisons below work well when data is a numpy
163+
# array (only for dense vector fields)
164+
# unfortunately numpy overrides the == operator in a way that causes
165+
# errors when used instead of "is"
166+
if (data is None or data is [] or data is {}) and self._required:
163167
raise ValidationException("Value required for this field.")
164168
return data
165169

@@ -417,13 +421,6 @@ class {{ k.name }}({{ k.parent }}):
417421
if self._element_type in ["float", "byte"]:
418422
kwargs["multi"] = True
419423
super().__init__(*args, **kwargs)
420-
421-
def _deserialize(self, data: Any) -> Any:
422-
if self._element_type == "float":
423-
return float(data)
424-
elif self._element_type == "byte":
425-
return int(data)
426-
return data
427424
{% elif k.field == "scaled_float" %}
428425
if 'scaling_factor' not in kwargs:
429426
if len(args) > 0:

0 commit comments

Comments
 (0)