Skip to content

Commit 8e954a5

Browse files
author
Alex Waszkiewicz
committed
Optimize decode_vector with NumPy arrays for faster vector decoding
1 parent 7904bce commit 8e954a5

File tree

1 file changed

+28
-47
lines changed

1 file changed

+28
-47
lines changed

python/cocoindex/engine_value.py

Lines changed: 28 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -365,57 +365,38 @@ def decode(value: Any) -> Any | None:
365365
return lambda value: value
366366

367367
if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
368-
field_path_str = "".join(field_path)
369-
if not isinstance(dst_type_variant, AnalyzedListType):
370-
raise ValueError(
371-
f"Type mismatch for `{''.join(field_path)}`: "
372-
f"declared `{dst_type_info.core_type}`, a list type expected"
373-
)
374-
expected_dim = (
375-
dst_type_variant.vector_info.dim
376-
if dst_type_variant and dst_type_variant.vector_info
377-
else None
378-
)
379368

380-
vec_elem_decoder = None
381-
scalar_dtype = None
382-
if dst_type_variant and dst_type_info.base_type is np.ndarray:
383-
if is_numpy_number_type(dst_type_variant.elem_type):
384-
scalar_dtype = dst_type_variant.elem_type
385-
else:
386-
# mypy: vector info exists for Vector kind
387-
assert src_type.vector is not None # type: ignore[unreachable]
388-
vec_elem_decoder = make_engine_value_decoder(
389-
field_path + ["[*]"],
390-
src_type.vector.element_type,
391-
analyze_type_info(
392-
dst_type_variant.elem_type if dst_type_variant else Any
393-
),
394-
)
369+
# force numeric vectors to be converted directly to NumPy arrays
370+
scalar_dtype = (
371+
dst_type_variant.elem_type
372+
if dst_type_variant and is_numpy_number_type(dst_type_variant.elem_type)
373+
else None
374+
)
375+
expected_dim = (
376+
dst_type_variant.vector_info.dim
377+
if dst_type_variant and dst_type_variant.vector_info
378+
else None
379+
)
395380

396-
def decode_vector(value: Any) -> Any | None:
397-
if value is None:
398-
if dst_type_info.nullable:
399-
return None
400-
raise ValueError(
401-
f"Received null for non-nullable vector `{field_path_str}`"
402-
)
403-
if not isinstance(value, (np.ndarray, list)):
404-
raise TypeError(
405-
f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
406-
)
407-
if expected_dim is not None and len(value) != expected_dim:
408-
raise ValueError(
409-
f"Vector dimension mismatch for `{field_path_str}`: "
410-
f"expected {expected_dim}, got {len(value)}"
411-
)
381+
def decode_vector(value: Any) -> Any | None:
382+
if value is None:
383+
return None if dst_type_info.nullable else np.zeros(expected_dim, dtype=scalar_dtype)
384+
385+
if not isinstance(value, (list, np.ndarray)):
386+
raise TypeError(f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}")
387+
388+
389+
arr = np.array(value, dtype=scalar_dtype)
390+
391+
if expected_dim is not None and arr.shape[0] != expected_dim:
392+
raise ValueError(
393+
f"Vector dimension mismatch for `{''.join(field_path)}`: "
394+
f"expected {expected_dim}, got {arr.shape[0]}"
395+
)
396+
return arr
412397

413-
if vec_elem_decoder is not None: # for Non-NDArray vector
414-
return [vec_elem_decoder(v) for v in value]
415-
else: # for NDArray vector
416-
return np.array(value, dtype=scalar_dtype)
398+
return decode_vector
417399

418-
return decode_vector
419400

420401
if isinstance(dst_type_variant, AnalyzedBasicType):
421402
if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):

0 commit comments

Comments
 (0)