Skip to content

Commit 5e3ec47

Browse files
integration test
1 parent 66ae37d commit 5e3ec47

File tree

2 files changed

+10
-0
lines changed

2 files changed

+10
-0
lines changed

test_elasticsearch/test_dsl/test_integration/_async/test_document.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ipaddress import ip_address
2626
from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, List, Optional, Tuple, Union
2727

28+
import numpy as np
2829
import pytest
2930
from pytest import raises
3031
from pytz import timezone
@@ -865,6 +866,7 @@ class Doc(AsyncDocument):
865866
float_vector: List[float] = mapped_field(DenseVector())
866867
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
867868
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
869+
numpy_float_vector: np.ndarray = mapped_field(DenseVector(use_numpy=True))
868870

869871
class Index:
870872
name = "vectors"
@@ -876,6 +878,7 @@ class Index:
876878
float_vector=[1.0, 1.2, 2.3],
877879
byte_vector=[12, 23, 34, 45],
878880
bit_vector=[18, -43, -112],
881+
numpy_float_vector=np.array([3.1, 2.25, 1.0]),
879882
)
880883
await doc.save(refresh=True)
881884

@@ -884,6 +887,8 @@ class Index:
884887
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
885888
assert docs[0].byte_vector == doc.byte_vector
886889
assert docs[0].bit_vector == doc.bit_vector
890+
assert type(docs[0].numpy_float_vector) is np.ndarray
891+
assert np.array_equal(docs[0].numpy_float_vector, doc.numpy_float_vector)
887892

888893

889894
@pytest.mark.anyio

test_elasticsearch/test_dsl/test_integration/_sync/test_document.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from ipaddress import ip_address
2626
from typing import TYPE_CHECKING, Any, Dict, Iterator, List, Optional, Tuple, Union
2727

28+
import numpy as np
2829
import pytest
2930
from pytest import raises
3031
from pytz import timezone
@@ -853,6 +854,7 @@ class Doc(Document):
853854
float_vector: List[float] = mapped_field(DenseVector())
854855
byte_vector: List[int] = mapped_field(DenseVector(element_type="byte"))
855856
bit_vector: List[int] = mapped_field(DenseVector(element_type="bit"))
857+
numpy_float_vector: np.ndarray = mapped_field(DenseVector(use_numpy=True))
856858

857859
class Index:
858860
name = "vectors"
@@ -864,6 +866,7 @@ class Index:
864866
float_vector=[1.0, 1.2, 2.3],
865867
byte_vector=[12, 23, 34, 45],
866868
bit_vector=[18, -43, -112],
869+
numpy_float_vector=np.array([3.1, 2.25, 1.0]),
867870
)
868871
doc.save(refresh=True)
869872

@@ -872,6 +875,8 @@ class Index:
872875
assert [round(v, 1) for v in docs[0].float_vector] == doc.float_vector
873876
assert docs[0].byte_vector == doc.byte_vector
874877
assert docs[0].bit_vector == doc.bit_vector
878+
assert type(docs[0].numpy_float_vector) is np.ndarray
879+
assert np.array_equal(docs[0].numpy_float_vector, doc.numpy_float_vector)
875880

876881

877882
@pytest.mark.sync

0 commit comments

Comments
 (0)