@@ -17,6 +17,7 @@
 
 import numpy as np
 import pandas as pd
+import pyarrow as pa
 
 from pymilvus.client.types import (
     DataType,
@@ -28,6 +29,7 @@
 )
 
 from .constants import (
+    ARROW_TYPE_CREATOR,
    DYNAMIC_FIELD_NAME,
    MB,
    NUMPY_TYPE_CREATOR,
@@ -260,6 +262,33 @@ def _persist_json_rows(self, local_path: str, **kwargs):
         logger.info(f"Successfully persist file {file_path}, row count: {len(rows)}")
         return [str(file_path)]
 
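+    # Build an explicit pyarrow schema from the collection schema so that
+    # to_parquet() does not rely on pandas dtype inference for list-, struct-,
+    # and JSON-valued columns.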
+    def _deduce_arrow_schema(self):
+        arrow_list = []
+        for field_name, field in self._fields.items():
+            if isinstance(field, FieldSchema) and (
+                (field.is_primary and field.auto_id) or field.is_function_output
+            ):
+                continue
+
+            if field.dtype.name not in ARROW_TYPE_CREATOR:
+                self._throw(f"Unsupported data type: {field.dtype.name}")
+
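+            # ARRAY fields map to list_<element type>; STRUCT fields map to a
+            # list of structs, matching how the buffer stores them (list of dicts).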
+            if field.dtype == DataType.ARRAY:
+                arrow_list.append(
+                    pa.field(field_name, pa.list_(ARROW_TYPE_CREATOR[field.element_type.name]))
+                )
+            elif field.dtype == DataType.STRUCT:
+                sub_list = []
+                for sub_field in field.fields:
+                    sub_list.append(
+                        pa.field(sub_field.name, ARROW_TYPE_CREATOR[sub_field.dtype.name])
+                    )
+                arrow_list.append(pa.field(field_name, pa.list_(pa.struct(sub_list))))
+            else:
+                arrow_list.append(pa.field(field_name, ARROW_TYPE_CREATOR[field.dtype.name]))
+
+        return pa.schema(arrow_list)
+
     def _persist_parquet(self, local_path: str, **kwargs):
         file_path = Path(local_path + ".parquet")
 
@@ -271,16 +300,7 @@ def _persist_parquet(self, local_path: str, **kwargs):
                 str_arr = []
                 for val in v:
                     str_arr.append(json.dumps(val))
-                data[k] = pd.Series(str_arr, dtype=None)
-            elif field_schema.dtype in {
-                DataType.BINARY_VECTOR,
-                DataType.FLOAT_VECTOR,
-                DataType.INT8_VECTOR,
-            }:
-                arr = []
-                for val in v:
-                    arr.append(np.array(val, dtype=NUMPY_TYPE_CREATOR[field_schema.dtype.name]))
-                data[k] = pd.Series(arr)
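+                # keep the raw JSON strings; the explicit arrow schema
+                # types the column on write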
+                data[k] = str_arr
             elif field_schema.dtype in {DataType.FLOAT16_VECTOR, DataType.BFLOAT16_VECTOR}:
                 # special process for float16 vector, the self._buffer stores bytes for
                 # float16 vector, convert the bytes to uint8 array
@@ -289,25 +309,9 @@ def _persist_parquet(self, local_path: str, **kwargs):
                 arr.append(
                     np.frombuffer(val, dtype=NUMPY_TYPE_CREATOR[field_schema.dtype.name])
                 )
-                data[k] = pd.Series(arr)
-            elif field_schema.dtype == DataType.ARRAY:
-                dt = NUMPY_TYPE_CREATOR[field_schema.element_type.name]
-                arr = []
-                for val in v:
-                    arr.append(None if val is None else np.array(val, dtype=dt))
-                data[k] = pd.Series(arr)
-            elif field_schema.dtype == DataType.STRUCT:
-                # bulk_import accepts struct array as list[dict],
-                data[k] = pd.Series(v, dtype=None)
-            elif field_schema.dtype.name in NUMPY_TYPE_CREATOR:
-                dt = NUMPY_TYPE_CREATOR[field_schema.dtype.name]
-                arr = []
-                for val in v:
-                    arr.append(None if val is None else dt.type(val))
-                data[k] = np.array(arr)
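+                # keep the plain list of numpy arrays; the explicit arrow
+                # schema performs the final conversion instead of pd.Series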
+                data[k] = arr
             else:
-                # dtype is null, let pandas deduce the type, might not work
-                data[k] = pd.Series(v)
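+                # ARRAY/STRUCT/scalar columns: pass the raw Python values
+                # through and let the explicit arrow schema coerce them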
+                data[k] = v
 
         # calculate a proper row group size
         row_group_size_min = 1000
@@ -329,7 +333,10 @@ def _persist_parquet(self, local_path: str, **kwargs):
         # write to Parquet file
         data_frame = pd.DataFrame(data=data)
         data_frame.to_parquet(
-            file_path, row_group_size=row_group_size, engine="pyarrow"
+            file_path,
+            row_group_size=row_group_size,
+            engine="pyarrow",
+            schema=self._deduce_arrow_schema(),
         )  # don't use fastparquet
 
         logger.info(
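
A minimal sketch of what the new schema= argument buys (column names and file paths below are illustrative, not from this commit): without an explicit schema, pandas stores list-valued columns as dtype "object" and pyarrow infers each column's Parquet type from the values; passing schema= makes pyarrow coerce the columns to the declared types, which is what _deduce_arrow_schema() pins down for the bulk writer.

import pandas as pd
import pyarrow as pa

# A scalar column plus a list-valued (vector-like) column.
df = pd.DataFrame({"id": [1, 2], "vec": [[0.1, 0.2], [0.3, 0.4]]})

# Without a schema, "vec" is an object column and its Parquet type is
# inferred from the values (Python floats come out as list<double>).
df.to_parquet("inferred.parquet", engine="pyarrow")

# With an explicit schema, pyarrow coerces "vec" to list<float32> instead,
# mirroring what _deduce_arrow_schema() does via ARROW_TYPE_CREATOR.
schema = pa.schema([
    pa.field("id", pa.int64()),
    pa.field("vec", pa.list_(pa.float32())),
])
df.to_parquet("pinned.parquet", engine="pyarrow", schema=schema)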