Skip to content

Commit 3319cae

Browse files
committed
feat: implemented _str_ and _repr_ methods for python schema classes
1 parent 1eac6b1 commit 3319cae

File tree

1 file changed

+75
-2
lines changed

1 file changed

+75
-2
lines changed

python/cocoindex/typing.py

Lines changed: 75 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def __class_getitem__(self, params):
102102
return Annotated[list[dtype], vector_info]
103103

104104

105-
TABLE_TYPES: tuple[str, str] = ("KTable", "LTable")
105+
TABLE_TYPES: tuple[str, str, str] = ("UTable", "KTable", "LTable")
106106
KEY_FIELD_NAME: str = "_key"
107107

108108

@@ -516,6 +516,13 @@ class VectorTypeSchema:
516516
element_type: "BasicValueType"
517517
dimension: int | None
518518

519+
def __str__(self) -> str:
520+
dimension_str = f", {self.dimension}" if self.dimension is not None else ""
521+
return f"Vector[{self.element_type}{dimension_str}]"
522+
523+
def __repr__(self) -> str:
524+
return self.__str__()
525+
519526
@staticmethod
520527
def decode(obj: dict[str, Any]) -> "VectorTypeSchema":
521528
return VectorTypeSchema(
@@ -534,6 +541,13 @@ def encode(self) -> dict[str, Any]:
534541
class UnionTypeSchema:
535542
variants: list["BasicValueType"]
536543

544+
def __str__(self) -> str:
545+
types_str = " | ".join(str(t) for t in self.variants)
546+
return f"Union[{types_str}]"
547+
548+
def __repr__(self) -> str:
549+
return self.__str__()
550+
537551
@staticmethod
538552
def decode(obj: dict[str, Any]) -> "UnionTypeSchema":
539553
return UnionTypeSchema(
@@ -573,6 +587,19 @@ class BasicValueType:
573587
vector: VectorTypeSchema | None = None
574588
union: UnionTypeSchema | None = None
575589

590+
def __str__(self) -> str:
591+
if self.kind == "Vector" and self.vector is not None:
592+
dimension_str = f", {self.vector.dimension}" if self.vector.dimension is not None else ""
593+
return f"Vector[{self.vector.element_type}{dimension_str}]"
594+
elif self.kind == "Union" and self.union is not None:
595+
types_str = " | ".join(str(t) for t in self.union.variants)
596+
return f"Union[{types_str}]"
597+
else:
598+
return self.kind
599+
600+
def __repr__(self) -> str:
601+
return self.__str__()
602+
576603
@staticmethod
577604
def decode(obj: dict[str, Any]) -> "BasicValueType":
578605
kind = obj["kind"]
@@ -603,6 +630,18 @@ class EnrichedValueType:
603630
nullable: bool = False
604631
attrs: dict[str, Any] | None = None
605632

633+
def __str__(self) -> str:
634+
result = str(self.type)
635+
if self.nullable:
636+
result += "?"
637+
if self.attrs:
638+
attrs_str = ", ".join(f"{k}: {v}" for k, v in self.attrs.items())
639+
result += f" [{attrs_str}]"
640+
return result
641+
642+
def __repr__(self) -> str:
643+
return self.__str__()
644+
606645
@staticmethod
607646
def decode(obj: dict[str, Any]) -> "EnrichedValueType":
608647
return EnrichedValueType(
@@ -625,6 +664,12 @@ class FieldSchema:
625664
name: str
626665
value_type: EnrichedValueType
627666

667+
def __str__(self) -> str:
668+
return f"{self.name}: {self.value_type}"
669+
670+
def __repr__(self) -> str:
671+
return self.__str__()
672+
628673
@staticmethod
629674
def decode(obj: dict[str, Any]) -> "FieldSchema":
630675
return FieldSchema(name=obj["name"], value_type=EnrichedValueType.decode(obj))
@@ -640,6 +685,13 @@ class StructSchema:
640685
fields: list[FieldSchema]
641686
description: str | None = None
642687

688+
def __str__(self) -> str:
689+
fields_str = ", ".join(str(field) for field in self.fields)
690+
return f"Struct({fields_str})"
691+
692+
def __repr__(self) -> str:
693+
return self.__str__()
694+
643695
@classmethod
644696
def decode(cls, obj: dict[str, Any]) -> Self:
645697
return cls(
@@ -658,6 +710,13 @@ def encode(self) -> dict[str, Any]:
658710
class StructType(StructSchema):
659711
kind: Literal["Struct"] = "Struct"
660712

713+
def __str__(self) -> str:
714+
# Use the parent's __str__ method for consistency
715+
return super().__str__()
716+
717+
def __repr__(self) -> str:
718+
return self.__str__()
719+
661720
def encode(self) -> dict[str, Any]:
662721
result = super().encode()
663722
result["kind"] = self.kind
@@ -666,10 +725,24 @@ def encode(self) -> dict[str, Any]:
666725

667726
@dataclasses.dataclass
668727
class TableType:
669-
kind: Literal["KTable", "LTable"]
728+
kind: Literal["UTable", "KTable", "LTable"]
670729
row: StructSchema
671730
num_key_parts: int | None = None # Only for KTable
672731

732+
def __str__(self) -> str:
733+
if self.kind == "KTable":
734+
num_parts = self.num_key_parts if self.num_key_parts is not None else 1
735+
table_kind = f"KTable({num_parts})"
736+
elif self.kind == "LTable":
737+
table_kind = "LTable"
738+
else: # UTable
739+
table_kind = "Table"
740+
741+
return f"{table_kind}({self.row})"
742+
743+
def __repr__(self) -> str:
744+
return self.__str__()
745+
673746
@staticmethod
674747
def decode(obj: dict[str, Any]) -> "TableType":
675748
row_obj = obj["row"]

0 commit comments

Comments
 (0)