File tree Expand file tree Collapse file tree 2 files changed +8
-2
lines changed
Expand file tree Collapse file tree 2 files changed +8
-2
lines changed Original file line number Diff line number Diff line change @@ -1556,6 +1556,7 @@ class DenseVector(Field):
15561556 :arg fields:
15571557 :arg synthetic_source_keep:
15581558 :arg use_numpy: if set to ``True``, deserialize as a numpy array.
1559+ :arg dtype: The numpy data type to use as a string, when ``use_numpy`` is ``True``. The default is "float32".
15591560 """
15601561
15611562 name = "dense_vector"
@@ -1589,6 +1590,7 @@ def __init__(
15891590 Literal ["none" , "arrays" , "all" ], "DefaultType"
15901591 ] = DEFAULT ,
15911592 use_numpy : bool = False ,
1593+ dtype : str = "float32" ,
15921594 ** kwargs : Any ,
15931595 ):
15941596 if dims is not DEFAULT :
@@ -1617,13 +1619,14 @@ def __init__(
16171619 if self ._element_type in ["float" , "byte" ]:
16181620 kwargs ["multi" ] = True
16191621 self ._use_numpy = use_numpy
1622+ self ._dtype = dtype
16201623 super ().__init__ (* args , ** kwargs )
16211624
16221625 def deserialize (self , data : Any ) -> Any :
16231626 if self ._use_numpy and isinstance (data , list ):
16241627 import numpy as np
16251628
1626- return np .array (data )
1629+ return np .array (data , dtype = getattr ( np , self . _dtype ) )
16271630 return super ().deserialize (data )
16281631
16291632 def clean (self , data : Any ) -> Any :
Original file line number Diff line number Diff line change @@ -219,6 +219,7 @@ class {{ k.name }}({{ k.parent }}):
219219 {% endif %}
220220 {% if k.field == " dense_vector" %}
221221 :arg use_numpy: if set to ``True``, deserialize as a numpy array.
222+ :arg dtype: The numpy data type to use as a string, when ``use_numpy`` is ``True``. The default is "float32".
222223 {% endif %}
223224 """
224225 name = "{ { k.field } }"
@@ -251,6 +252,7 @@ class {{ k.name }}({{ k.parent }}):
251252 {% endfor %}
252253 {% if k.field == " dense_vector" %}
253254 use_numpy: bool = False,
255+ dtype: str = "float32",
254256 {% endif %}
255257 **kwargs: Any
256258 ):
@@ -423,12 +425,13 @@ class {{ k.name }}({{ k.parent }}):
423425 if self._element_type in ["float", "byte"]:
424426 kwargs["multi"] = True
425427 self._use_numpy = use_numpy
428+ self._dtype = dtype
426429 super().__init__(*args, **kwargs)
427430
428431 def deserialize(self, data: Any) -> Any:
429432 if self._use_numpy and isinstance(data, list):
430433 import numpy as np
431- return np.array(data)
434+ return np.array(data, dtype=getattr(np, self._dtype) )
432435 return super().deserialize(data)
433436
434437 def clean(self, data: Any) -> Any:
You can’t perform that action at this time.
0 commit comments