Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions python/cocoindex/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,15 @@ def _create_data_slice(
def _spec_kind(spec: Any) -> str:
return spec.__class__.__name__

def _spec_dump(spec: Any) -> dict[str, Any]:
return spec.__dict__
def _spec_value_dump(spec: Any) -> Any:
"""Recursively dump a spec object and its nested attributes to a dictionary."""
if hasattr(spec, '__dict__'):
return {k: _spec_value_dump(v) for k, v in spec.__dict__.items()}
elif isinstance(spec, (list, tuple)):
return [_spec_value_dump(item) for item in spec]
elif isinstance(spec, dict):
return {k: _spec_value_dump(v) for k, v in spec.items()}
return spec

T = TypeVar('T')

Expand Down Expand Up @@ -161,7 +168,7 @@ def transform(self, fn_spec: op.FunctionSpec, /, name: str | None = None) -> Dat
lambda target_scope, name:
flow_builder_state.engine_flow_builder.transform(
_spec_kind(fn_spec),
_spec_dump(fn_spec),
_spec_value_dump(fn_spec),
args,
target_scope,
flow_builder_state.field_name_builder.build_name(
Expand Down Expand Up @@ -252,7 +259,7 @@ def export(self, name: str, target_spec: op.StorageSpec, /, *,
{"field_name": field_name, "metric": metric.value}
for field_name, metric in vector_index]
self._flow_builder_state.engine_flow_builder.export(
name, _spec_kind(target_spec), _spec_dump(target_spec),
name, _spec_kind(target_spec), _spec_value_dump(target_spec),
index_options, self._engine_data_collector)


Expand Down Expand Up @@ -293,7 +300,7 @@ def add_source(self, spec: op.SourceSpec, /, name: str | None = None) -> DataSli
self._state,
lambda target_scope, name: self._state.engine_flow_builder.add_source(
_spec_kind(spec),
_spec_dump(spec),
_spec_value_dump(spec),
target_scope,
self._state.field_name_builder.build_name(
name, prefix=_to_snake_case(_spec_kind(spec))+'_'),
Expand Down
15 changes: 15 additions & 0 deletions python/cocoindex/functions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""All builtin functions."""
from dataclasses import dataclass
from typing import Annotated, Any

import sentence_transformers
Expand All @@ -11,6 +12,20 @@ class SplitRecursively(op.FunctionSpec):
chunk_overlap: int
language: str | None = None

@dataclass
class MistralModelSpec:
"""A specification for a Mistral model."""
model_id: str
isq_type: str

class ExtractByMistral(op.FunctionSpec):
"""Extract information from a text using a Mistral model."""

model: MistralModelSpec
# Expected to be generated by `cocoindex.typing.encode_enriched_type()`
output_type: dict[str, Any]
instructions: str | None = None

class SentenceTransformerEmbed(op.FunctionSpec):
"""
`SentenceTransformerEmbed` embeds a text into a vector space using the [SentenceTransformer](https://huggingface.co/sentence-transformers) library.
Expand Down
9 changes: 8 additions & 1 deletion python/cocoindex/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import collections
import dataclasses
import types
from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING
from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload

class Vector(NamedTuple):
dim: int | None
Expand Down Expand Up @@ -182,6 +182,13 @@ def encode_enriched_type_info(enriched_type_info: AnalyzedTypeInfo) -> dict[str,

return encoded

@overload
def encode_enriched_type(t: None) -> None:
...

@overload
def encode_enriched_type(t: Any) -> dict[str, Any]:
...

def encode_enriched_type(t) -> dict[str, Any] | None:
"""
Expand Down