@@ -365,57 +365,38 @@ def decode(value: Any) -> Any | None:
365365 return lambda value : value
366366
367367 if isinstance (src_type , BasicValueType ) and src_type .kind == "Vector" :
368- field_path_str = "" .join (field_path )
369- if not isinstance (dst_type_variant , AnalyzedListType ):
370- raise ValueError (
371- f"Type mismatch for `{ '' .join (field_path )} `: "
372- f"declared `{ dst_type_info .core_type } `, a list type expected"
373- )
374- expected_dim = (
375- dst_type_variant .vector_info .dim
376- if dst_type_variant and dst_type_variant .vector_info
377- else None
378- )
379368
380- vec_elem_decoder = None
381- scalar_dtype = None
382- if dst_type_variant and dst_type_info .base_type is np .ndarray :
383- if is_numpy_number_type (dst_type_variant .elem_type ):
384- scalar_dtype = dst_type_variant .elem_type
385- else :
386- # mypy: vector info exists for Vector kind
387- assert src_type .vector is not None # type: ignore[unreachable]
388- vec_elem_decoder = make_engine_value_decoder (
389- field_path + ["[*]" ],
390- src_type .vector .element_type ,
391- analyze_type_info (
392- dst_type_variant .elem_type if dst_type_variant else Any
393- ),
394- )
369+ # force numeric vectors to be converted directly to NumPy arrays
370+ scalar_dtype = (
371+ dst_type_variant .elem_type
372+ if dst_type_variant and is_numpy_number_type (dst_type_variant .elem_type )
373+ else None
374+ )
375+ expected_dim = (
376+ dst_type_variant .vector_info .dim
377+ if dst_type_variant and dst_type_variant .vector_info
378+ else None
379+ )
395380
396- def decode_vector (value : Any ) -> Any | None :
397- if value is None :
398- if dst_type_info .nullable :
399- return None
400- raise ValueError (
401- f"Received null for non-nullable vector `{ field_path_str } `"
402- )
403- if not isinstance ( value , ( np . ndarray , list )):
404- raise TypeError (
405- f"Expected NDArray or list for vector ` { field_path_str } `, got { type ( value ) } "
406- )
407- if expected_dim is not None and len ( value ) != expected_dim :
408- raise ValueError (
409- f"Vector dimension mismatch for ` { field_path_str } `: "
410- f"expected { expected_dim } , got { len ( value ) } "
411- )
381+ def decode_vector (value : Any ) -> Any | None :
382+ if value is None :
383+ return None if dst_type_info .nullable else np . zeros ( expected_dim , dtype = scalar_dtype )
384+
385+ if not isinstance ( value , ( list , np . ndarray )):
386+ raise TypeError ( f"Expected NDArray or list for vector `{ '' . join ( field_path ) } `, got { type ( value ) } " )
387+
388+
389+ arr = np . array ( value , dtype = scalar_dtype )
390+
391+ if expected_dim is not None and arr . shape [ 0 ] != expected_dim :
392+ raise ValueError (
393+ f"Vector dimension mismatch for ` { '' . join ( field_path ) } `: "
394+ f"expected { expected_dim } , got { arr . shape [ 0 ] } "
395+ )
396+ return arr
412397
413- if vec_elem_decoder is not None : # for Non-NDArray vector
414- return [vec_elem_decoder (v ) for v in value ]
415- else : # for NDArray vector
416- return np .array (value , dtype = scalar_dtype )
398+ return decode_vector
417399
418- return decode_vector
419400
420401 if isinstance (dst_type_variant , AnalyzedBasicType ):
421402 if not _is_type_kind_convertible_to (src_type_kind , dst_type_variant .kind ):
0 commit comments