Skip to content

Commit 8170c6c

Browse files
yhmo and XuanYang-cn authored
Fix a bug of bulkwriter to support all-empty struct list (#3192)
Signed-off-by: yhmo <yihua.mo@zilliz.com> Co-authored-by: XuanYang-cn <xuan.yang@zilliz.com>
1 parent 27d1440 commit 8170c6c

File tree

11 files changed

+91
-40
lines changed

11 files changed

+91
-40
lines changed

examples/bulk_import/bulk_writer_all_types.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import pytz
33
import time
44
import numpy as np
5+
from pathlib import Path
56
from typing import List
67

78
from pymilvus import (
@@ -16,6 +17,9 @@
1617
get_import_progress,
1718
)
1819

20+
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert/"
21+
Path(LOCAL_FILES_PATH).mkdir(exist_ok=True)
22+
1923
# minio
2024
MINIO_ADDRESS = "0.0.0.0:9000"
2125
MINIO_SECRET_KEY = "minioadmin"
@@ -216,7 +220,7 @@ def remote_writer(schema: CollectionSchema, file_type: BulkFileType):
216220
with RemoteBulkWriter(
217221
schema=schema,
218222
remote_path="bulk_data",
219-
local_path="/tmp/PARQUET",
223+
local_path=LOCAL_FILES_PATH,
220224
connect_param=RemoteBulkWriter.S3ConnectParam(
221225
endpoint=MINIO_ADDRESS,
222226
access_key=MINIO_ACCESS_KEY,
@@ -269,7 +273,7 @@ def local_writer(schema: CollectionSchema, file_type: BulkFileType):
269273
print(f"\n===================== local writer ({file_type.name}) ====================")
270274
writer = LocalBulkWriter(
271275
schema=schema,
272-
local_path="./" + file_type.name,
276+
local_path=LOCAL_FILES_PATH,
273277
chunk_size=16 * 1024 * 1024,
274278
file_type=file_type
275279
)

examples/bulk_import/example_bulkinsert_json.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from minio import Minio
88
from minio.error import S3Error
9+
from pathlib import Path
910

1011
from pymilvus import (
1112
DataType,
@@ -18,7 +19,8 @@
1819
)
1920

2021
# Local path to generate JSON files
21-
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert"
22+
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert/"
23+
Path(LOCAL_FILES_PATH).mkdir(exist_ok=True)
2224

2325
# Milvus service address
2426
_HOST = '127.0.0.1'

examples/bulk_import/example_bulkinsert_parquet.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import time
44
import os
5+
from pathlib import Path
56
from typing import List
67

78
from minio import Minio
@@ -24,6 +25,7 @@
2425

2526
# Local path to generate files
2627
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert/"
28+
Path(LOCAL_FILES_PATH).mkdir(exist_ok=True)
2729

2830
# Milvus service address
2931
_HOST = '127.0.0.1'
@@ -324,6 +326,7 @@ def verify(data):
324326

325327
# Extract IDs from the data
326328
ids = [int(data[_ID_FIELD_NAME][k]) for k in indices]
329+
ids = [int(val) if isinstance(val, np.int64) else val for val in ids]
327330
results = client.query(collection_name=_COLLECTION_NAME,
328331
filter=f"{_ID_FIELD_NAME} in {ids}",
329332
output_fields=["*"])

examples/orm_deprecated/bulk_import/example_bulkinsert_csv.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from minio import Minio
88
from minio.error import S3Error
9+
from pathlib import Path
910

1011
from pymilvus import (
1112
connections,
@@ -16,7 +17,8 @@
1617
)
1718

1819

19-
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert"
20+
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert/"
21+
Path(LOCAL_FILES_PATH).mkdir(exist_ok=True)
2022

2123
# Milvus service address
2224
_HOST = '127.0.0.1'

examples/orm_deprecated/bulk_import/example_bulkinsert_numpy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from minio import Minio
88
from minio.error import S3Error
9+
from pathlib import Path
910

1011
from pymilvus import (
1112
connections,
@@ -37,6 +38,7 @@
3738

3839
# Local path to generate Numpy files
3940
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert/"
41+
Path(LOCAL_FILES_PATH).mkdir(exist_ok=True)
4042

4143
# Milvus service address
4244
_HOST = '127.0.0.1'

examples/orm_deprecated/bulk_import/example_bulkinsert_withfunction.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
import time
55
import os
66

7+
from pathlib import Path
8+
79
from pymilvus import (
810
connections,
911
FieldSchema, CollectionSchema, DataType,
@@ -14,7 +16,8 @@
1416
)
1517

1618

17-
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert"
19+
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert/"
20+
Path(LOCAL_FILES_PATH).mkdir(exist_ok=True)
1821

1922
# Milvus service address
2023
_HOST = '127.0.0.1'

examples/orm_deprecated/bulk_import/example_bulkwriter.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,13 @@
1414
import logging
1515
import threading
1616
import time
17+
from pathlib import Path
1718
from typing import List
19+
1820
import numpy as np
1921
import pandas as pd
2022

21-
from examples.orm_deprecated.bulk_import.data_gengerator import *
23+
from examples.bulk_import.data_gengerator import *
2224

2325
logging.basicConfig(level=logging.INFO)
2426

@@ -38,6 +40,9 @@
3840
get_import_progress,
3941
)
4042

43+
LOCAL_FILES_PATH = "/tmp/milvus_bulkinsert/"
44+
Path(LOCAL_FILES_PATH).mkdir(exist_ok=True)
45+
4146
# minio
4247
MINIO_ADDRESS = "0.0.0.0:9000"
4348
MINIO_SECRET_KEY = "minioadmin"
@@ -121,7 +126,7 @@ def local_writer_simple(schema: CollectionSchema, file_type: BulkFileType):
121126
print(f"\n===================== local writer ({file_type.name}) ====================")
122127
with LocalBulkWriter(
123128
schema=schema,
124-
local_path="/tmp/bulk_writer",
129+
local_path=LOCAL_FILES_PATH,
125130
segment_size=128*1024*1024,
126131
file_type=file_type,
127132
) as local_writer:
@@ -181,7 +186,7 @@ def _append_row(writer: LocalBulkWriter, begin: int, end: int):
181186

182187
local_writer = LocalBulkWriter(
183188
schema=schema,
184-
local_path="/tmp/bulk_writer",
189+
local_path=LOCAL_FILES_PATH,
185190
segment_size=128 * 1024 * 1024, # 128MB
186191
file_type=BulkFileType.JSON,
187192
)

examples/orm_deprecated/bulk_import/example_bulkwriter_with_nullable.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import time
55
from typing import List
66

7-
from examples.orm_deprecated.bulk_import.data_gengerator import *
7+
from examples.bulk_import.data_gengerator import *
88

99
logging.basicConfig(level=logging.INFO)
1010

pymilvus/bulk_writer/buffer.py

Lines changed: 36 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import numpy as np
1919
import pandas as pd
20+
import pyarrow as pa
2021

2122
from pymilvus.client.types import (
2223
DataType,
@@ -28,6 +29,7 @@
2829
)
2930

3031
from .constants import (
32+
ARROW_TYPE_CREATOR,
3133
DYNAMIC_FIELD_NAME,
3234
MB,
3335
NUMPY_TYPE_CREATOR,
@@ -260,6 +262,33 @@ def _persist_json_rows(self, local_path: str, **kwargs):
260262
logger.info(f"Successfully persist file {file_path}, row count: {len(rows)}")
261263
return [str(file_path)]
262264

265+
def _deduce_arrow_schema(self):
266+
arrow_list = []
267+
for field_name, field in self._fields.items():
268+
if isinstance(field, FieldSchema) and (
269+
(field.is_primary and field.auto_id) or field.is_function_output
270+
):
271+
continue
272+
273+
if field.dtype.name not in ARROW_TYPE_CREATOR:
274+
self._throw(f"Unsupported data type: {field.dtype.name}")
275+
276+
if field.dtype == DataType.ARRAY:
277+
arrow_list.append(
278+
pa.field(field_name, pa.list_(ARROW_TYPE_CREATOR[field.element_type.name]))
279+
)
280+
elif field.dtype == DataType.STRUCT:
281+
sub_list = []
282+
for sub_field in field.fields:
283+
sub_list.append(
284+
pa.field(sub_field.name, ARROW_TYPE_CREATOR[sub_field.dtype.name])
285+
)
286+
arrow_list.append(pa.field(field_name, pa.list_(pa.struct(sub_list))))
287+
else:
288+
arrow_list.append(pa.field(field_name, ARROW_TYPE_CREATOR[field.dtype.name]))
289+
290+
return pa.schema(arrow_list)
291+
263292
def _persist_parquet(self, local_path: str, **kwargs):
264293
file_path = Path(local_path + ".parquet")
265294

@@ -271,16 +300,7 @@ def _persist_parquet(self, local_path: str, **kwargs):
271300
str_arr = []
272301
for val in v:
273302
str_arr.append(json.dumps(val))
274-
data[k] = pd.Series(str_arr, dtype=None)
275-
elif field_schema.dtype in {
276-
DataType.BINARY_VECTOR,
277-
DataType.FLOAT_VECTOR,
278-
DataType.INT8_VECTOR,
279-
}:
280-
arr = []
281-
for val in v:
282-
arr.append(np.array(val, dtype=NUMPY_TYPE_CREATOR[field_schema.dtype.name]))
283-
data[k] = pd.Series(arr)
303+
data[k] = str_arr
284304
elif field_schema.dtype in {DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR}:
285305
# special process for float16 vector, the self._buffer stores bytes for
286306
# float16 vector, convert the bytes to uint8 array
@@ -289,25 +309,9 @@ def _persist_parquet(self, local_path: str, **kwargs):
289309
arr.append(
290310
np.frombuffer(val, dtype=NUMPY_TYPE_CREATOR[field_schema.dtype.name])
291311
)
292-
data[k] = pd.Series(arr)
293-
elif field_schema.dtype == DataType.ARRAY:
294-
dt = NUMPY_TYPE_CREATOR[field_schema.element_type.name]
295-
arr = []
296-
for val in v:
297-
arr.append(None if val is None else np.array(val, dtype=dt))
298-
data[k] = pd.Series(arr)
299-
elif field_schema.dtype == DataType.STRUCT:
300-
# bulk_import accepts struct array as list[dict],
301-
data[k] = pd.Series(v, dtype=None)
302-
elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
303-
dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]
304-
arr = []
305-
for val in v:
306-
arr.append(None if val is None else dt.type(val))
307-
data[k] = np.array(arr)
312+
data[k] = arr
308313
else:
309-
# dtype is null, let pandas deduce the type, might not work
310-
data[k] = pd.Series(v)
314+
data[k] = v
311315

312316
# calculate a proper row group size
313317
row_group_size_min = 1000
@@ -329,7 +333,10 @@ def _persist_parquet(self, local_path: str, **kwargs):
329333
# write to Parquet file
330334
data_frame = pd.DataFrame(data=data)
331335
data_frame.to_parquet(
332-
file_path, row_group_size=row_group_size, engine="pyarrow"
336+
file_path,
337+
row_group_size=row_group_size,
338+
engine="pyarrow",
339+
schema=self._deduce_arrow_schema(),
333340
) # don't use fastparquet
334341

335342
logger.info(

pymilvus/bulk_writer/constants.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from enum import Enum, IntEnum
1414

1515
import numpy as np
16+
import pyarrow as pa
1617

1718
from pymilvus.client.types import (
1819
DataType,
@@ -81,12 +82,34 @@
8182
DataType.BINARY_VECTOR.name: np.dtype("uint8"),
8283
DataType.FLOAT16_VECTOR.name: np.dtype("uint8"),
8384
DataType.BFLOAT16_VECTOR.name: np.dtype("uint8"),
84-
DataType.SPARSE_FLOAT_VECTOR: None,
85+
DataType.SPARSE_FLOAT_VECTOR.name: None,
8586
DataType.INT8_VECTOR.name: np.dtype("int8"),
8687
DataType.ARRAY.name: None,
8788
DataType.STRUCT.name: None,
8889
}
8990

91+
ARROW_TYPE_CREATOR = {
92+
DataType.BOOL.name: pa.bool_(),
93+
DataType.INT8.name: pa.int8(),
94+
DataType.INT16.name: pa.int16(),
95+
DataType.INT32.name: pa.int32(),
96+
DataType.INT64.name: pa.int64(),
97+
DataType.FLOAT.name: pa.float32(),
98+
DataType.DOUBLE.name: pa.float64(),
99+
DataType.VARCHAR.name: pa.string(),
100+
DataType.JSON.name: pa.string(), # in numpy/parquet file, json objects are stored as string
101+
DataType.TIMESTAMPTZ.name: pa.string(),
102+
DataType.GEOMETRY.name: pa.string(),
103+
DataType.FLOAT_VECTOR.name: pa.list_(pa.float32()),
104+
DataType.BINARY_VECTOR.name: pa.list_(pa.uint8()),
105+
DataType.FLOAT16_VECTOR.name: pa.list_(pa.uint8()),
106+
DataType.BFLOAT16_VECTOR.name: pa.list_(pa.uint8()),
107+
DataType.SPARSE_FLOAT_VECTOR.name: pa.string(), # in numpy/parquet file, sparse vectors are stored as string
108+
DataType.INT8_VECTOR.name: pa.list_(pa.int8()),
109+
DataType.ARRAY.name: None,
110+
DataType.STRUCT.name: None,
111+
}
112+
90113

91114
class BulkFileType(IntEnum):
92115
NUMPY = 1

0 commit comments

Comments (0)