Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
100 changes: 79 additions & 21 deletions python/cocoindex/tests/test_typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ class SimpleDataclass:
value: int


@dataclasses.dataclass
class SimpleDataclassWithDescription:
"""This is a simple dataclass with a description."""

name: str
value: int


class SimpleNamedTuple(NamedTuple):
name: str
value: int
Expand Down Expand Up @@ -357,45 +365,96 @@ def test_encode_enriched_type_none() -> None:
assert result is None


def test_encode_enriched_type_struct() -> None:
def test_encode_enriched_dataclass() -> None:
typ = SimpleDataclass
result = encode_enriched_type(typ)
assert result["type"]["kind"] == "Struct"
assert len(result["type"]["fields"]) == 2
assert result["type"]["fields"][0]["name"] == "name"
assert result["type"]["fields"][0]["type"]["kind"] == "Str"
assert result["type"]["fields"][1]["name"] == "value"
assert result["type"]["fields"][1]["type"]["kind"] == "Int64"
assert result == {
"type": {
"kind": "Struct",
"description": "SimpleDataclass(name: str, value: int)",
"fields": [
{"name": "name", "type": {"kind": "Str"}},
{"name": "value", "type": {"kind": "Int64"}},
],
},
}


def test_encode_enriched_dataclass_with_description() -> None:
typ = SimpleDataclassWithDescription
result = encode_enriched_type(typ)
assert result == {
"type": {
"kind": "Struct",
"description": "This is a simple dataclass with a description.",
"fields": [
{"name": "name", "type": {"kind": "Str"}},
{"name": "value", "type": {"kind": "Int64"}},
],
},
}


def test_encode_named_tuple() -> None:
typ = SimpleNamedTuple
result = encode_enriched_type(typ)
assert result == {
"type": {
"kind": "Struct",
"description": "SimpleNamedTuple(name, value)",
"fields": [
{"name": "name", "type": {"kind": "Str"}},
{"name": "value", "type": {"kind": "Int64"}},
],
},
}


def test_encode_enriched_type_vector() -> None:
typ = NDArray[np.float32]
result = encode_enriched_type(typ)
assert result["type"]["kind"] == "Vector"
assert result["type"]["element_type"]["kind"] == "Float32"
assert result["type"]["dimension"] is None
assert result == {
"type": {
"kind": "Vector",
"element_type": {"kind": "Float32"},
"dimension": None,
},
}


def test_encode_enriched_type_ltable() -> None:
typ = list[SimpleDataclass]
result = encode_enriched_type(typ)
assert result["type"]["kind"] == "LTable"
assert "fields" in result["type"]["row"]
assert len(result["type"]["row"]["fields"]) == 2
assert result == {
"type": {
"kind": "LTable",
"row": {
"description": "SimpleDataclass(name: str, value: int)",
"fields": [
{"name": "name", "type": {"kind": "Str"}},
{"name": "value", "type": {"kind": "Int64"}},
],
},
},
}


def test_encode_enriched_type_with_attrs() -> None:
typ = Annotated[str, TypeAttr("key", "value")]
result = encode_enriched_type(typ)
assert result["type"]["kind"] == "Str"
assert result["attrs"] == {"key": "value"}
assert result == {
"type": {"kind": "Str"},
"attrs": {"key": "value"},
}


def test_encode_enriched_type_nullable() -> None:
typ = str | None
result = encode_enriched_type(typ)
assert result["type"]["kind"] == "Str"
assert result["nullable"] is True
assert result == {
"type": {"kind": "Str"},
"nullable": True,
}


def test_encode_scalar_numpy_types_schema() -> None:
Expand All @@ -405,10 +464,9 @@ def test_encode_scalar_numpy_types_schema() -> None:
(np.float64, "Float64"),
]:
schema = encode_enriched_type(np_type)
assert schema["type"]["kind"] == expected_kind, (
f"Expected {expected_kind} for {np_type}, got {schema['type']['kind']}"
)
assert not schema.get("nullable", False)
assert schema == {
"type": {"kind": expected_kind},
}, f"Expected kind {expected_kind} for {np_type}, got {schema}"


def test_annotated_struct_with_type_kind() -> None:
Expand Down
2 changes: 1 addition & 1 deletion python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,7 +444,7 @@ def add_fields_from_struct(struct_info: AnalyzedStructType) -> None:
add_fields_from_struct(struct_info)

result["fields"] = fields
if doc := inspect.getdoc(struct_info):
if doc := inspect.getdoc(struct_info.struct_type):
result["description"] = doc
return result, num_key_parts

Expand Down
Loading