Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 75 additions & 2 deletions python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def __class_getitem__(self, params):
return Annotated[list[dtype], vector_info]


TABLE_TYPES: tuple[str, str] = ("KTable", "LTable")
TABLE_TYPES: tuple[str, str, str] = ("UTable", "KTable", "LTable")
KEY_FIELD_NAME: str = "_key"


Expand Down Expand Up @@ -516,6 +516,13 @@ class VectorTypeSchema:
element_type: "BasicValueType"
dimension: int | None

def __str__(self) -> str:
dimension_str = f", {self.dimension}" if self.dimension is not None else ""
return f"Vector[{self.element_type}{dimension_str}]"

def __repr__(self) -> str:
return self.__str__()

@staticmethod
def decode(obj: dict[str, Any]) -> "VectorTypeSchema":
return VectorTypeSchema(
Expand All @@ -534,6 +541,13 @@ def encode(self) -> dict[str, Any]:
class UnionTypeSchema:
variants: list["BasicValueType"]

def __str__(self) -> str:
types_str = " | ".join(str(t) for t in self.variants)
return f"Union[{types_str}]"

def __repr__(self) -> str:
return self.__str__()

@staticmethod
def decode(obj: dict[str, Any]) -> "UnionTypeSchema":
return UnionTypeSchema(
Expand Down Expand Up @@ -573,6 +587,19 @@ class BasicValueType:
vector: VectorTypeSchema | None = None
union: UnionTypeSchema | None = None

def __str__(self) -> str:
if self.kind == "Vector" and self.vector is not None:
dimension_str = f", {self.vector.dimension}" if self.vector.dimension is not None else ""
return f"Vector[{self.vector.element_type}{dimension_str}]"
elif self.kind == "Union" and self.union is not None:
types_str = " | ".join(str(t) for t in self.union.variants)
return f"Union[{types_str}]"
else:
return self.kind

def __repr__(self) -> str:
return self.__str__()

@staticmethod
def decode(obj: dict[str, Any]) -> "BasicValueType":
kind = obj["kind"]
Expand Down Expand Up @@ -603,6 +630,18 @@ class EnrichedValueType:
nullable: bool = False
attrs: dict[str, Any] | None = None

def __str__(self) -> str:
result = str(self.type)
if self.nullable:
result += "?"
if self.attrs:
attrs_str = ", ".join(f"{k}: {v}" for k, v in self.attrs.items())
result += f" [{attrs_str}]"
return result

def __repr__(self) -> str:
return self.__str__()

@staticmethod
def decode(obj: dict[str, Any]) -> "EnrichedValueType":
return EnrichedValueType(
Expand All @@ -625,6 +664,12 @@ class FieldSchema:
name: str
value_type: EnrichedValueType

def __str__(self) -> str:
return f"{self.name}: {self.value_type}"

def __repr__(self) -> str:
return self.__str__()

@staticmethod
def decode(obj: dict[str, Any]) -> "FieldSchema":
return FieldSchema(name=obj["name"], value_type=EnrichedValueType.decode(obj))
Expand All @@ -640,6 +685,13 @@ class StructSchema:
fields: list[FieldSchema]
description: str | None = None

def __str__(self) -> str:
fields_str = ", ".join(str(field) for field in self.fields)
return f"Struct({fields_str})"

def __repr__(self) -> str:
return self.__str__()

@classmethod
def decode(cls, obj: dict[str, Any]) -> Self:
return cls(
Expand All @@ -658,6 +710,13 @@ def encode(self) -> dict[str, Any]:
class StructType(StructSchema):
kind: Literal["Struct"] = "Struct"

def __str__(self) -> str:
# Use the parent's __str__ method for consistency
return super().__str__()

def __repr__(self) -> str:
return self.__str__()

def encode(self) -> dict[str, Any]:
result = super().encode()
result["kind"] = self.kind
Expand All @@ -666,10 +725,24 @@ def encode(self) -> dict[str, Any]:

@dataclasses.dataclass
class TableType:
kind: Literal["KTable", "LTable"]
kind: Literal["UTable", "KTable", "LTable"]
row: StructSchema
num_key_parts: int | None = None # Only for KTable

def __str__(self) -> str:
if self.kind == "KTable":
num_parts = self.num_key_parts if self.num_key_parts is not None else 1
table_kind = f"KTable({num_parts})"
elif self.kind == "LTable":
table_kind = "LTable"
else: # UTable
table_kind = "Table"

return f"{table_kind}({self.row})"

def __repr__(self) -> str:
return self.__str__()

@staticmethod
def decode(obj: dict[str, Any]) -> "TableType":
row_obj = obj["row"]
Expand Down
Loading