Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 28 additions & 47 deletions python/cocoindex/engine_value.py
Original file line number Diff line number Diff line change
Expand Up @@ -365,57 +365,38 @@ def decode(value: Any) -> Any | None:
return lambda value: value

if isinstance(src_type, BasicValueType) and src_type.kind == "Vector":
field_path_str = "".join(field_path)
if not isinstance(dst_type_variant, AnalyzedListType):
raise ValueError(
f"Type mismatch for `{''.join(field_path)}`: "
f"declared `{dst_type_info.core_type}`, a list type expected"
)
expected_dim = (
dst_type_variant.vector_info.dim
if dst_type_variant and dst_type_variant.vector_info
else None
)

vec_elem_decoder = None
scalar_dtype = None
if dst_type_variant and dst_type_info.base_type is np.ndarray:
if is_numpy_number_type(dst_type_variant.elem_type):
scalar_dtype = dst_type_variant.elem_type
else:
# mypy: vector info exists for Vector kind
assert src_type.vector is not None # type: ignore[unreachable]
vec_elem_decoder = make_engine_value_decoder(
field_path + ["[*]"],
src_type.vector.element_type,
analyze_type_info(
dst_type_variant.elem_type if dst_type_variant else Any
),
)
# force numeric vectors to be converted directly to NumPy arrays
scalar_dtype = (
dst_type_variant.elem_type
if dst_type_variant and is_numpy_number_type(dst_type_variant.elem_type)
else None
)
expected_dim = (
dst_type_variant.vector_info.dim
if dst_type_variant and dst_type_variant.vector_info
else None
)

def decode_vector(value: Any) -> Any | None:
if value is None:
if dst_type_info.nullable:
return None
raise ValueError(
f"Received null for non-nullable vector `{field_path_str}`"
)
if not isinstance(value, (np.ndarray, list)):
raise TypeError(
f"Expected NDArray or list for vector `{field_path_str}`, got {type(value)}"
)
if expected_dim is not None and len(value) != expected_dim:
raise ValueError(
f"Vector dimension mismatch for `{field_path_str}`: "
f"expected {expected_dim}, got {len(value)}"
)
def decode_vector(value: Any) -> Any | None:
if value is None:
return None if dst_type_info.nullable else np.zeros(expected_dim, dtype=scalar_dtype)

if not isinstance(value, (list, np.ndarray)):
raise TypeError(f"Expected NDArray or list for vector `{''.join(field_path)}`, got {type(value)}")


arr = np.array(value, dtype=scalar_dtype)

if expected_dim is not None and arr.shape[0] != expected_dim:
raise ValueError(
f"Vector dimension mismatch for `{''.join(field_path)}`: "
f"expected {expected_dim}, got {arr.shape[0]}"
)
return arr

if vec_elem_decoder is not None: # for Non-NDArray vector
return [vec_elem_decoder(v) for v in value]
else: # for NDArray vector
return np.array(value, dtype=scalar_dtype)
return decode_vector

return decode_vector

if isinstance(dst_type_variant, AnalyzedBasicType):
if not _is_type_kind_convertible_to(src_type_kind, dst_type_variant.kind):
Expand Down
Loading