Skip to content

Commit 195f2ce

Browse files
Support dense vectors based on numpy arrays
1 parent bbca81a commit 195f2ce

File tree

4 files changed

+14
-12
lines changed

4 files changed

+14
-12
lines changed

elasticsearch/dsl/field.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,11 +1616,11 @@ def __init__(
16161616
kwargs["multi"] = True
16171617
super().__init__(*args, **kwargs)
16181618

1619-
def _deserialize(self, data: Any) -> Any:
1620-
if self._element_type == "float":
1621-
return float(data)
1622-
elif self._element_type == "byte":
1623-
return int(data)
1619+
def clean(self, data: Any) -> Any:
1620+
if data is not None:
1621+
data = self.deserialize(data)
1622+
if (data is None or len(data) == 0) and self._required:
1623+
raise ValidationException("Value required for this field.")
16241624
return data
16251625

16261626

elasticsearch/dsl/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,9 @@ def to_dict(self, skip_empty: bool = True) -> Dict[str, Any]:
612612
if skip_empty:
613613
# don't serialize empty values
614614
# careful not to include numeric zeros
615-
if v in ([], {}, None):
615+
# the "is" operator is used below because it is the only comparison
616+
# that works for numpy arrays
617+
if v is [] or v is {} or v is None:
616618
continue
617619

618620
out[k] = v

examples/quotes/backend/quotes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ async def search_quotes(req: SearchRequest) -> SearchResponse:
135135
def embed_quotes(quotes):
136136
embeddings = model.encode([q.quote for q in quotes])
137137
for q, e in zip(quotes, embeddings):
138-
q.embedding = e.tolist()
138+
q.embedding = e
139139

140140

141141
async def ingest_quotes():

utils/templates/field.py.tpl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -418,11 +418,11 @@ class {{ k.name }}({{ k.parent }}):
418418
kwargs["multi"] = True
419419
super().__init__(*args, **kwargs)
420420

421-
def _deserialize(self, data: Any) -> Any:
422-
if self._element_type == "float":
423-
return float(data)
424-
elif self._element_type == "byte":
425-
return int(data)
421+
def clean(self, data: Any) -> Any:
422+
if data is not None:
423+
data = self.deserialize(data)
424+
if (data is None or len(data) == 0) and self._required:
425+
raise ValidationException("Value required for this field.")
426426
return data
427427
{% elif k.field == "scaled_float" %}
428428
if 'scaling_factor' not in kwargs:

0 commit comments

Comments
 (0)