Skip to content

Commit d4187c2

Browse files
add optional dtype argument
1 parent 5e3ec47 commit d4187c2

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

elasticsearch/dsl/field.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1556,6 +1556,7 @@ class DenseVector(Field):
15561556
:arg fields:
15571557
:arg synthetic_source_keep:
15581558
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
1559+
:arg dtype: The numpy data type to use as a string, when ``use_numpy`` is ``True``. The default is "float32".
15591560
"""
15601561

15611562
name = "dense_vector"
@@ -1589,6 +1590,7 @@ def __init__(
15891590
Literal["none", "arrays", "all"], "DefaultType"
15901591
] = DEFAULT,
15911592
use_numpy: bool = False,
1593+
dtype: str = "float32",
15921594
**kwargs: Any,
15931595
):
15941596
if dims is not DEFAULT:
@@ -1617,13 +1619,14 @@ def __init__(
16171619
if self._element_type in ["float", "byte"]:
16181620
kwargs["multi"] = True
16191621
self._use_numpy = use_numpy
1622+
self._dtype = dtype
16201623
super().__init__(*args, **kwargs)
16211624

16221625
def deserialize(self, data: Any) -> Any:
16231626
if self._use_numpy and isinstance(data, list):
16241627
import numpy as np
16251628

1626-
return np.array(data)
1629+
return np.array(data, dtype=getattr(np, self._dtype))
16271630
return super().deserialize(data)
16281631

16291632
def clean(self, data: Any) -> Any:

utils/templates/field.py.tpl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,7 @@ class {{ k.name }}({{ k.parent }}):
219219
{% endif %}
220220
{% if k.field == "dense_vector" %}
221221
:arg use_numpy: if set to ``True``, deserialize as a numpy array.
222+
:arg dtype: The numpy data type to use as a string, when ``use_numpy`` is ``True``. The default is "float32".
222223
{% endif %}
223224
"""
224225
name = "{{ k.field }}"
@@ -251,6 +252,7 @@ class {{ k.name }}({{ k.parent }}):
251252
{% endfor %}
252253
{% if k.field == "dense_vector" %}
253254
use_numpy: bool = False,
255+
dtype: str = "float32",
254256
{% endif %}
255257
**kwargs: Any
256258
):
@@ -423,12 +425,13 @@ class {{ k.name }}({{ k.parent }}):
423425
if self._element_type in ["float", "byte"]:
424426
kwargs["multi"] = True
425427
self._use_numpy = use_numpy
428+
self._dtype = dtype
426429
super().__init__(*args, **kwargs)
427430

428431
def deserialize(self, data: Any) -> Any:
429432
if self._use_numpy and isinstance(data, list):
430433
import numpy as np
431-
return np.array(data)
434+
return np.array(data, dtype=getattr(np, self._dtype))
432435
return super().deserialize(data)
433436

434437
def clean(self, data: Any) -> Any:

0 commit comments

Comments
 (0)