Skip to content

Commit 88bc083

Browse files
[Backport 2.6] float scalar datatype supports floating-point numbers (#3172) (#3173)
Backport of #3172 to `2.6`. Signed-off-by: wangting0128 <ting.wang@zilliz.com> Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: wt <ting.wang@zilliz.com>
1 parent 0adee89 commit 88bc083

File tree

2 files changed

+7
-5
lines changed

2 files changed

+7
-5
lines changed

pymilvus/orm/schema.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,9 @@ def __init__(self, name: str, dtype: DataType, description: str = "", **kwargs)
482482
if self.default_value.WhichOneof("data") is None:
483483
self.default_value = None
484484
else:
485-
self.default_value = infer_default_value_bydata(kwargs.get("default_value"))
485+
self.default_value = infer_default_value_bydata(
486+
kwargs.get("default_value"), dtype=self._dtype
487+
)
486488
self.element_type = kwargs.get("element_type")
487489
if "mmap_enabled" in kwargs:
488490
self._type_params["mmap_enabled"] = kwargs["mmap_enabled"]
@@ -1211,13 +1213,13 @@ def check_schema(schema: CollectionSchema):
12111213
raise SchemaNotReadyException(message=ExceptionsMessage.NoVector)
12121214

12131215

1214-
def infer_default_value_bydata(data: Any):
1216+
def infer_default_value_bydata(data: Any, dtype: DataType = None):
12151217
if data is None:
12161218
return None
12171219
default_data = schema_types.ValueField()
12181220
d_type = DataType.UNKNOWN
12191221
if is_scalar(data):
1220-
d_type = infer_dtype_by_scalar_data(data)
1222+
d_type = infer_dtype_by_scalar_data(data, dtype)
12211223
if d_type is DataType.BOOL:
12221224
default_data.bool_data = data
12231225
elif d_type in (DataType.INT8, DataType.INT16, DataType.INT32):

pymilvus/orm/types.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -71,10 +71,10 @@ def is_numeric_datatype(data_type: DataType):
7171
return is_float_datatype(data_type) or is_integer_datatype(data_type)
7272

7373

74-
def infer_dtype_by_scalar_data(data: Any):
74+
def infer_dtype_by_scalar_data(data: Any, dtype: DataType = None):
7575
if isinstance(data, list):
7676
return DataType.ARRAY
77-
if isinstance(data, np.float32):
77+
if isinstance(data, np.float32) or (is_float(data) and dtype == DataType.FLOAT):
7878
return DataType.FLOAT
7979
if isinstance(data, (float, np.float64, np.double)) or is_float(data):
8080
return DataType.DOUBLE

0 commit comments

Comments
 (0)