Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion python/cocoindex/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ def encode_engine_value(value: Any) -> Any:
]
if is_namedtuple_type(type(value)):
return [encode_engine_value(getattr(value, name)) for name in value._fields]
if isinstance(value, np.number):
return value.item()
if isinstance(value, np.ndarray):
return value
if isinstance(value, (list, tuple)):
Expand Down Expand Up @@ -86,6 +88,20 @@ def make_engine_value_decoder(
field_path, src_type["fields"], dst_type_info.struct_type
)

if dst_type_info.np_number_type is not None:
numpy_type = dst_type_info.np_number_type

def decode_numpy_scalar(value: Any) -> Any | None:
if value is None:
if dst_type_info.nullable:
return None
raise ValueError(
f"Received null for non-nullable scalar `{''.join(field_path)}`"
)
return numpy_type(value)

return decode_numpy_scalar

if src_type_kind in TABLE_TYPES:
field_path.append("[*]")
elem_type_info = analyze_type_info(dst_type_info.elem_type)
Expand Down Expand Up @@ -127,7 +143,8 @@ def decode(value: Any) -> Any | None:
return lambda value: uuid.UUID(bytes=value)

if src_type_kind == "Vector":
dtype_info = DtypeRegistry.get_by_dtype(dst_type_info.np_number_type)
elem_type_info = analyze_type_info(dst_type_info.elem_type)
dtype_info = DtypeRegistry.get_by_dtype(elem_type_info.np_number_type)

def decode_vector(value: Any) -> Any | None:
if value is None:
Expand Down
26 changes: 26 additions & 0 deletions python/cocoindex/tests/test_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,19 @@ def test_encode_engine_value_date_time_types():
assert encode_engine_value(dt) == dt


def test_encode_scalar_numpy_values():
"""Test encoding scalar NumPy values to engine-compatible values."""
test_cases = [
(np.int64(42), 42),
(np.float32(3.14), pytest.approx(3.14)),
(np.float64(2.718), pytest.approx(2.718)),
]
for np_value, expected in test_cases:
encoded = encode_engine_value(np_value)
assert encoded == expected
assert isinstance(encoded, (int, float))


def test_encode_engine_value_struct():
order = Order(order_id="O123", name="mixed nuts", price=25.0)
assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"]
Expand Down Expand Up @@ -197,6 +210,19 @@ def test_make_engine_value_decoder_basic_types():
assert decoder(value) == value


def test_decode_scalar_numpy_values():
test_cases = [
({"kind": "Int64"}, np.int64, 42, np.int64(42)),
({"kind": "Float32"}, np.float32, 3.14, np.float32(3.14)),
({"kind": "Float64"}, np.float64, 2.718, np.float64(2.718)),
]
for src_type, dst_type, input_value, expected in test_cases:
decoder = make_engine_value_decoder(["field"], src_type, dst_type)
result = decoder(input_value)
assert isinstance(result, dst_type)
assert result == expected


@pytest.mark.parametrize(
"data_type, engine_val, expected",
[
Expand Down
40 changes: 35 additions & 5 deletions python/cocoindex/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def test_ndarray_float32_no_dim():
elem_type=Float32,
key_type=None,
struct_type=None,
np_number_type=np.float32,
np_number_type=None,
attrs=None,
nullable=False,
)
Expand All @@ -63,7 +63,7 @@ def test_vector_float32_no_dim():
elem_type=Float32,
key_type=None,
struct_type=None,
np_number_type=np.float32,
np_number_type=None,
attrs=None,
nullable=False,
)
Expand All @@ -78,7 +78,7 @@ def test_ndarray_float64_with_dim():
elem_type=Float64,
key_type=None,
struct_type=None,
np_number_type=np.float64,
np_number_type=None,
attrs=None,
nullable=False,
)
Expand All @@ -93,7 +93,7 @@ def test_vector_float32_with_dim():
elem_type=Float32,
key_type=None,
struct_type=None,
np_number_type=np.float32,
np_number_type=None,
attrs=None,
nullable=False,
)
Expand All @@ -117,12 +117,29 @@ def test_nullable_ndarray():
elem_type=Float32,
key_type=None,
struct_type=None,
np_number_type=np.float32,
np_number_type=None,
attrs=None,
nullable=True,
)


def test_scalar_numpy_types():
for np_type, expected_kind in [
(np.int64, "Int64"),
(np.float32, "Float32"),
(np.float64, "Float64"),
]:
type_info = analyze_type_info(np_type)
assert (
type_info.kind == expected_kind
), f"Expected {expected_kind} for {np_type}, got {type_info.kind}"
assert (
type_info.np_number_type == np_type
), f"Expected {np_type}, got {type_info.np_number_type}"
assert type_info.elem_type is None
assert type_info.vector_info is None


def test_vector_str():
typ = Vector[str]
result = analyze_type_info(typ)
Expand Down Expand Up @@ -487,6 +504,19 @@ def test_encode_enriched_type_nullable():
assert result["nullable"] is True


def test_encode_scalar_numpy_types_schema():
for np_type, expected_kind in [
(np.int64, "Int64"),
(np.float32, "Float32"),
(np.float64, "Float64"),
]:
schema = encode_enriched_type(np_type)
assert (
schema["type"]["kind"] == expected_kind
), f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
assert not schema.get("nullable", False)


def test_invalid_struct_kind():
typ = Annotated[SimpleDataclass, TypeKind("Vector")]
with pytest.raises(ValueError, match="Unexpected type kind for struct: Vector"):
Expand Down
17 changes: 13 additions & 4 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,13 @@ def get_by_dtype(cls, dtype: Any) -> DtypeInfo | None:
)
return cls._mappings.get(dtype)

@staticmethod
def get_by_kind(kind: str) -> DtypeInfo | None:
return next(
(info for info in DtypeRegistry._mappings.values() if info.kind == kind),
None,
)

@staticmethod
def supported_dtypes() -> KeysView[type]:
"""Get a list of supported NumPy dtypes."""
Expand Down Expand Up @@ -193,6 +200,7 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
attrs: dict[str, Any] | None = None
vector_info: VectorInfo | None = None
kind: str | None = None
np_number_type: type | None = None
for attr in annotations:
if isinstance(attr, TypeAttr):
if attrs is None:
Expand All @@ -202,11 +210,12 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
vector_info = attr
elif isinstance(attr, TypeKind):
kind = attr.kind
if dtype_info := DtypeRegistry.get_by_kind(attr.kind):
np_number_type = dtype_info.numpy_dtype

struct_type: type | None = None
elem_type: ElementType | None = None
key_type: type | None = None
np_number_type: type | None = None
if _is_struct_type(t):
struct_type = t

Expand Down Expand Up @@ -240,11 +249,11 @@ def analyze_type_info(t: Any) -> AnalyzedTypeInfo:
if not dtype_args:
raise ValueError("Invalid dtype specification for NDArray")

np_number_type = dtype_args[0]
dtype_info = DtypeRegistry.get_by_dtype(np_number_type)
numpy_dtype = dtype_args[0]
dtype_info = DtypeRegistry.get_by_dtype(numpy_dtype)
if dtype_info is None:
raise ValueError(
f"Unsupported numpy dtype for NDArray: {np_number_type}. "
f"Unsupported numpy dtype for NDArray: {numpy_dtype}. "
f"Supported dtypes: {DtypeRegistry.supported_dtypes()}"
)
elem_type = dtype_info.annotated_type
Expand Down
Loading