diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 10a820be5..0edad24c3 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,6 +26,21 @@ permissions: contents: read jobs: + format-check: + name: Check Python formatting + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.11 + - name: Install Ruff + run: | + pip install ruff + - name: Check Python formatting + run: | + ruff format --check . + test: name: Run test uses: ./.github/workflows/_test.yml diff --git a/.vscode/settings.json b/.vscode/settings.json index 795857a39..34129185f 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -3,5 +3,7 @@ "cocoindex", "reindexing", "timedelta" - ] + ], + "editor.formatOnSave": true, + "python.formatting.provider": "ruff" } \ No newline at end of file diff --git a/examples/amazon_s3_embedding/main.py b/examples/amazon_s3_embedding/main.py index 82bc242a0..6d2f878bd 100644 --- a/examples/amazon_s3_embedding/main.py +++ b/examples/amazon_s3_embedding/main.py @@ -3,8 +3,11 @@ import cocoindex import os + @cocoindex.flow_def(name="AmazonS3TextEmbedding") -def amazon_s3_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def amazon_s3_text_embedding_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that embeds text from Amazon S3 into a vector database. """ @@ -18,21 +21,32 @@ def amazon_s3_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scop prefix=prefix, included_patterns=["*.md", "*.txt", "*.docx"], binary=False, - sqs_queue_url=sqs_queue_url)) + sqs_queue_url=sqs_queue_url, + ) + ) doc_embeddings = data_scope.add_collector() with data_scope["documents"].row() as doc: doc["chunks"] = doc["content"].transform( cocoindex.functions.SplitRecursively(), - language="markdown", chunk_size=2000, chunk_overlap=500) + language="markdown", + chunk_size=2000, + chunk_overlap=500, + ) with doc["chunks"].row() as chunk: chunk["embedding"] = chunk["text"].transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")) - doc_embeddings.collect(filename=doc["filename"], location=chunk["location"], - text=chunk["text"], embedding=chunk["embedding"]) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) + doc_embeddings.collect( + filename=doc["filename"], + location=chunk["location"], + text=chunk["text"], + embedding=chunk["embedding"], + ) doc_embeddings.export( "doc_embeddings", @@ -41,7 +55,11 @@ def amazon_s3_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scop vector_indexes=[ cocoindex.VectorIndexDef( field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)]) + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ) + ], + ) + query_handler = cocoindex.query.SimpleSemanticsQueryHandler( name="SemanticsSearch", @@ -49,8 +67,12 @@ def amazon_s3_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scop target_name="doc_embeddings", query_transform_flow=lambda text: text.transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")), - default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ), + default_similarity_metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, +) + def _main(): # Use a `FlowLiveUpdater` to keep the flow data updated. 
@@ -58,7 +80,7 @@ def _main(): # Run queries in a loop to demonstrate the query capabilities. while True: query = input("Enter search query (or Enter to quit): ") - if query == '': + if query == "": break results, _ = query_handler.search(query, 10) print("\nSearch results:") @@ -68,6 +90,7 @@ def _main(): print("---") print() + if __name__ == "__main__": load_dotenv() cocoindex.init() diff --git a/examples/code_embedding/main.py b/examples/code_embedding/main.py index 01e3d75b3..43d2e9f62 100644 --- a/examples/code_embedding/main.py +++ b/examples/code_embedding/main.py @@ -3,40 +3,59 @@ import cocoindex import os + @cocoindex.op.function() def extract_extension(filename: str) -> str: """Extract the extension of a filename.""" return os.path.splitext(filename)[1] + @cocoindex.transform_flow() -def code_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]: +def code_to_embedding( + text: cocoindex.DataSlice[str], +) -> cocoindex.DataSlice[list[float]]: """ Embed the text using a SentenceTransformer model. """ return text.transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) + @cocoindex.flow_def(name="CodeEmbedding") -def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def code_embedding_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that embeds files into a vector database. """ data_scope["files"] = flow_builder.add_source( - cocoindex.sources.LocalFile(path="../..", - included_patterns=["*.py", "*.rs", "*.toml", "*.md", "*.mdx"], - excluded_patterns=["**/.*", "target", "**/node_modules"])) + cocoindex.sources.LocalFile( + path="../..", + included_patterns=["*.py", "*.rs", "*.toml", "*.md", "*.mdx"], + excluded_patterns=["**/.*", "target", "**/node_modules"], + ) + ) code_embeddings = data_scope.add_collector() with data_scope["files"].row() as file: file["extension"] = file["filename"].transform(extract_extension) file["chunks"] = file["content"].transform( cocoindex.functions.SplitRecursively(), - language=file["extension"], chunk_size=1000, chunk_overlap=300) + language=file["extension"], + chunk_size=1000, + chunk_overlap=300, + ) with file["chunks"].row() as chunk: chunk["embedding"] = chunk["text"].call(code_to_embedding) - code_embeddings.collect(filename=file["filename"], location=chunk["location"], - code=chunk["text"], embedding=chunk["embedding"]) + code_embeddings.collect( + filename=file["filename"], + location=chunk["location"], + code=chunk["text"], + embedding=chunk["embedding"], + ) code_embeddings.export( "code_embeddings", @@ -45,26 +64,35 @@ def code_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind vector_indexes=[ cocoindex.VectorIndexDef( field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)]) + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ) + ], + ) def search(pool: ConnectionPool, query: str, top_k: int = 5): # Get the table name, for the export target in the code_embedding_flow above. - table_name = cocoindex.utils.get_target_storage_default_name(code_embedding_flow, "code_embeddings") + table_name = cocoindex.utils.get_target_storage_default_name( + code_embedding_flow, "code_embeddings" + ) # Evaluate the transform flow defined above with the input query, to get the embedding. 
query_vector = code_to_embedding.eval(query) # Run the query and get the results. with pool.connection() as conn: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT filename, code, embedding <=> %s::vector AS distance FROM {table_name} ORDER BY distance LIMIT %s - """, (query_vector, top_k)) + """, + (query_vector, top_k), + ) return [ {"filename": row[0], "code": row[1], "score": 1.0 - row[2]} for row in cur.fetchall() ] + def _main(): # Make sure the flow is built and up-to-date. stats = code_embedding_flow.update() @@ -75,7 +103,7 @@ def _main(): # Run queries in a loop to demonstrate the query capabilities. while True: query = input("Enter search query (or Enter to quit): ") - if query == '': + if query == "": break # Run the query function with the database connection pool and the query. results = search(pool, query) @@ -86,6 +114,7 @@ def _main(): print("---") print() + if __name__ == "__main__": load_dotenv() cocoindex.init() diff --git a/examples/docs_to_knowledge_graph/main.py b/examples/docs_to_knowledge_graph/main.py index ef4b9ed2e..f9dcc0cd1 100644 --- a/examples/docs_to_knowledge_graph/main.py +++ b/examples/docs_to_knowledge_graph/main.py @@ -1,27 +1,35 @@ """ This example shows how to extract relationships from documents and build a knowledge graph. """ + import dataclasses import cocoindex + @dataclasses.dataclass class DocumentSummary: """Describe a summary of a document.""" + title: str summary: str + @dataclasses.dataclass class Relationship: """ Describe a relationship between two entities. Subject and object should be Core CocoIndex concepts only, should be nouns. For example, `CocoIndex`, `Incremental Processing`, `ETL`, `Data` etc. """ + subject: str predicate: str object: str + @cocoindex.flow_def(name="DocsToKG") -def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def docs_to_kg_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that extracts relationship from files and build knowledge graph. 
""" @@ -32,11 +40,14 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D uri="bolt://localhost:7687", user="neo4j", password="cocoindex", - )) + ), + ) data_scope["documents"] = flow_builder.add_source( - cocoindex.sources.LocalFile(path="../../docs/docs/core", - included_patterns=["*.md", "*.mdx"])) + cocoindex.sources.LocalFile( + path="../../docs/docs/core", included_patterns=["*.md", "*.mdx"] + ) + ) document_node = data_scope.add_collector() entity_relationship = data_scope.add_collector() @@ -48,24 +59,34 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D cocoindex.functions.ExtractByLlm( llm_spec=cocoindex.LlmSpec( # Supported LLM: https://cocoindex.io/docs/ai/llm - api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"), + api_type=cocoindex.LlmApiType.OPENAI, + model="gpt-4o", + ), output_type=DocumentSummary, - instruction="Please summarize the content of the document.")) + instruction="Please summarize the content of the document.", + ) + ) document_node.collect( - filename=doc["filename"], title=doc["summary"]["title"], - summary=doc["summary"]["summary"]) + filename=doc["filename"], + title=doc["summary"]["title"], + summary=doc["summary"]["summary"], + ) # extract relationships from document doc["relationships"] = doc["content"].transform( cocoindex.functions.ExtractByLlm( llm_spec=cocoindex.LlmSpec( # Supported LLM: https://cocoindex.io/docs/ai/llm - api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"), - output_type=list[Relationship], - instruction=( - "Please extract relationships from CocoIndex documents. " - "Focus on concepts and ignore examples and code. " - ))) + api_type=cocoindex.LlmApiType.OPENAI, + model="gpt-4o", + ), + output_type=list[Relationship], + instruction=( + "Please extract relationships from CocoIndex documents. " + "Focus on concepts and ignore examples and code. 
" + ), + ) + ) with doc["relationships"].row() as relationship: # relationship between two entities @@ -77,22 +98,23 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D ) # mention of an entity in a document, for subject entity_mention.collect( - id=cocoindex.GeneratedField.UUID, entity=relationship["subject"], + id=cocoindex.GeneratedField.UUID, + entity=relationship["subject"], filename=doc["filename"], ) # mention of an entity in a document, for object entity_mention.collect( - id=cocoindex.GeneratedField.UUID, entity=relationship["object"], + id=cocoindex.GeneratedField.UUID, + entity=relationship["object"], filename=doc["filename"], ) - # export to neo4j document_node.export( "document_node", cocoindex.storages.Neo4j( - connection=conn_spec, - mapping=cocoindex.storages.Nodes(label="Document")), + connection=conn_spec, mapping=cocoindex.storages.Nodes(label="Document") + ), primary_key_fields=["filename"], ) # Declare reference Node to reference entity node in a relationship @@ -113,15 +135,17 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D label="Entity", fields=[ cocoindex.storages.TargetFieldMapping( - source="subject", target="value"), - ] + source="subject", target="value" + ), + ], ), target=cocoindex.storages.NodeFromFields( label="Entity", fields=[ cocoindex.storages.TargetFieldMapping( - source="object", target="value"), - ] + source="object", target="value" + ), + ], ), ), ), @@ -139,8 +163,11 @@ def docs_to_kg_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.D ), target=cocoindex.storages.NodeFromFields( label="Entity", - fields=[cocoindex.storages.TargetFieldMapping( - source="entity", target="value")], + fields=[ + cocoindex.storages.TargetFieldMapping( + source="entity", target="value" + ) + ], ), ), ), diff --git a/examples/fastapi_server_docker/main.py b/examples/fastapi_server_docker/main.py index 11ba61d06..6535ecabd 100644 --- a/examples/fastapi_server_docker/main.py +++ b/examples/fastapi_server_docker/main.py @@ -5,37 +5,49 @@ from psycopg_pool import ConnectionPool import os + @cocoindex.transform_flow() -def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]: +def text_to_embedding( + text: cocoindex.DataSlice[str], +) -> cocoindex.DataSlice[list[float]]: """ Embed the text using a SentenceTransformer model. This is a shared logic between indexing and querying. """ return text.transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) + @cocoindex.flow_def(name="MarkdownEmbeddingFastApiExample") -def markdown_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def markdown_embedding_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that embeds markdown files into a vector database. 
""" data_scope["documents"] = flow_builder.add_source( - cocoindex.sources.LocalFile(path="files")) + cocoindex.sources.LocalFile(path="files") + ) doc_embeddings = data_scope.add_collector() with data_scope["documents"].row() as doc: doc["chunks"] = doc["content"].transform( cocoindex.functions.SplitRecursively(), - language="markdown", chunk_size=2000, chunk_overlap=500) - + language="markdown", + chunk_size=2000, + chunk_overlap=500, + ) + with doc["chunks"].row() as chunk: chunk["embedding"] = text_to_embedding(chunk["text"]) doc_embeddings.collect( - filename=doc["filename"], + filename=doc["filename"], location=chunk["location"], - text=chunk["text"], - embedding=chunk["embedding"] + text=chunk["text"], + embedding=chunk["embedding"], ) doc_embeddings.export( @@ -45,27 +57,38 @@ def markdown_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coc vector_indexes=[ cocoindex.VectorIndexDef( field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)]) + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ) + ], + ) + def search(pool: ConnectionPool, query: str, top_k: int = 5): # Get the table name, for the export target in the text_embedding_flow above. - table_name = cocoindex.utils.get_target_storage_default_name(markdown_embedding_flow, "doc_embeddings") + table_name = cocoindex.utils.get_target_storage_default_name( + markdown_embedding_flow, "doc_embeddings" + ) # Evaluate the transform flow defined above with the input query, to get the embedding. query_vector = text_to_embedding.eval(query) # Run the query and get the results. with pool.connection() as conn: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT filename, text, embedding <=> %s::vector AS distance FROM {table_name} ORDER BY distance LIMIT %s - """, (query_vector, top_k)) + """, + (query_vector, top_k), + ) return [ {"filename": row[0], "text": row[1], "score": 1.0 - row[2]} for row in cur.fetchall() ] + fastapi_app = FastAPI() + @fastapi_app.on_event("startup") def startup_event(): load_dotenv() @@ -73,11 +96,16 @@ def startup_event(): # Initialize database connection pool fastapi_app.state.pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL")) + @fastapi_app.get("/search") -def search_endpoint(q: str = Query(..., description="Search query"), limit: int = Query(5, description="Number of results")): +def search_endpoint( + q: str = Query(..., description="Search query"), + limit: int = Query(5, description="Number of results"), +): results = search(fastapi_app.state.pool, q, limit) return {"results": results} + if __name__ == "__main__": load_dotenv() cocoindex.init() diff --git a/examples/gdrive_text_embedding/main.py b/examples/gdrive_text_embedding/main.py index 0a057049d..5fcde61eb 100644 --- a/examples/gdrive_text_embedding/main.py +++ b/examples/gdrive_text_embedding/main.py @@ -4,18 +4,26 @@ import datetime import os + @cocoindex.transform_flow() -def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]: +def text_to_embedding( + text: cocoindex.DataSlice[str], +) -> cocoindex.DataSlice[list[float]]: """ Embed the text using a SentenceTransformer model. This is a shared logic between indexing and querying, so extract it as a function. 
""" return text.transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) + @cocoindex.flow_def(name="GoogleDriveTextEmbedding") -def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def gdrive_text_embedding_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that embeds text into a vector database. """ @@ -26,20 +34,29 @@ def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.sources.GoogleDrive( service_account_credential_path=credential_path, root_folder_ids=root_folder_ids, - recent_changes_poll_interval=datetime.timedelta(seconds=10)), - refresh_interval=datetime.timedelta(minutes=1)) + recent_changes_poll_interval=datetime.timedelta(seconds=10), + ), + refresh_interval=datetime.timedelta(minutes=1), + ) doc_embeddings = data_scope.add_collector() with data_scope["documents"].row() as doc: doc["chunks"] = doc["content"].transform( cocoindex.functions.SplitRecursively(), - language="markdown", chunk_size=2000, chunk_overlap=500) + language="markdown", + chunk_size=2000, + chunk_overlap=500, + ) with doc["chunks"].row() as chunk: chunk["embedding"] = text_to_embedding(chunk["text"]) - doc_embeddings.collect(filename=doc["filename"], location=chunk["location"], - text=chunk["text"], embedding=chunk["embedding"]) + doc_embeddings.collect( + filename=doc["filename"], + location=chunk["location"], + text=chunk["text"], + embedding=chunk["embedding"], + ) doc_embeddings.export( "doc_embeddings", @@ -48,32 +65,42 @@ def gdrive_text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: vector_indexes=[ cocoindex.VectorIndexDef( field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)]) + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ) + ], + ) + def search(pool: ConnectionPool, query: str, top_k: int = 5): # Get the table name, for the export target in the gdrive_text_embedding_flow above. - table_name = cocoindex.utils.get_target_storage_default_name(gdrive_text_embedding_flow, "doc_embeddings") + table_name = cocoindex.utils.get_target_storage_default_name( + gdrive_text_embedding_flow, "doc_embeddings" + ) # Evaluate the transform flow defined above with the input query, to get the embedding. query_vector = text_to_embedding.eval(query) # Run the query and get the results. with pool.connection() as conn: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT filename, text, embedding <=> %s::vector AS distance FROM {table_name} ORDER BY distance LIMIT %s - """, (query_vector, top_k)) + """, + (query_vector, top_k), + ) return [ {"filename": row[0], "text": row[1], "score": 1.0 - row[2]} for row in cur.fetchall() ] + def _main(): # Initialize the database connection pool. pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL")) # Run queries in a loop to demonstrate the query capabilities. while True: query = input("Enter search query (or Enter to quit): ") - if query == '': + if query == "": break # Run the query function with the database connection pool and the query. 
results = search(pool, query) @@ -84,6 +111,7 @@ def _main(): print("---") print() + if __name__ == "__main__": load_dotenv() cocoindex.init() diff --git a/examples/image_search/main.py b/examples/image_search/main.py index d0b331224..cdb37dffc 100644 --- a/examples/image_search/main.py +++ b/examples/image_search/main.py @@ -20,6 +20,7 @@ QDRANT_GRPC_URL = os.getenv("QDRANT_GRPC_URL", "http://localhost:6334/") CLIP_MODEL_NAME = "openai/clip-vit-large-patch14" + @functools.cache def get_clip_model() -> tuple[CLIPModel, CLIPProcessor]: model = CLIPModel.from_pretrained(CLIP_MODEL_NAME) @@ -49,14 +50,20 @@ def embed_image(img_bytes: bytes) -> cocoindex.Vector[cocoindex.Float32, Literal with torch.no_grad(): features = model.get_image_features(**inputs) return features[0].tolist() - + # CocoIndex flow: Ingest images, extract captions, embed, export to Qdrant @cocoindex.flow_def(name="ImageObjectEmbedding") -def image_object_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def image_object_embedding_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): data_scope["images"] = flow_builder.add_source( - cocoindex.sources.LocalFile(path="img", included_patterns=["*.jpg", "*.jpeg", "*.png"], binary=True), - refresh_interval=datetime.timedelta(minutes=1) # Poll for changes every 1 minute + cocoindex.sources.LocalFile( + path="img", included_patterns=["*.jpg", "*.jpeg", "*.png"], binary=True + ), + refresh_interval=datetime.timedelta( + minutes=1 + ), # Poll for changes every 1 minute ) img_embeddings = data_scope.add_collector() with data_scope["images"].row() as img: @@ -76,6 +83,7 @@ def image_object_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: setup_by_user=True, ) + # --- FastAPI app for web API --- app = FastAPI() app.add_middleware( @@ -88,36 +96,35 @@ def image_object_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: # Serve images from the 'img' directory at /img app.mount("/img", StaticFiles(directory="img"), name="img") + # --- CocoIndex initialization on startup --- @app.on_event("startup") def startup_event(): load_dotenv() cocoindex.init() # Initialize Qdrant client - app.state.qdrant_client = QdrantClient( - url=QDRANT_GRPC_URL, - prefer_grpc=True - ) + app.state.qdrant_client = QdrantClient(url=QDRANT_GRPC_URL, prefer_grpc=True) app.state.live_updater = cocoindex.FlowLiveUpdater(image_object_embedding_flow) app.state.live_updater.start() + @app.get("/search") -def search(q: str = Query(..., description="Search query"), limit: int = Query(5, description="Number of results")): +def search( + q: str = Query(..., description="Search query"), + limit: int = Query(5, description="Number of results"), +): # Get the embedding for the query query_embedding = embed_query(q) - + # Search in Qdrant search_results = app.state.qdrant_client.search( collection_name="image_search", query_vector=("embedding", query_embedding), - limit=limit + limit=limit, ) - + # Format results out = [] for result in search_results: - out.append({ - "filename": result.payload["filename"], - "score": result.score - }) + out.append({"filename": result.payload["filename"], "score": result.score}) return {"results": out} diff --git a/examples/manuals_llm_extraction/main.py b/examples/manuals_llm_extraction/main.py index 31d57f050..17be9713c 100644 --- a/examples/manuals_llm_extraction/main.py +++ b/examples/manuals_llm_extraction/main.py @@ -8,9 +8,11 @@ import cocoindex + class PdfToMarkdown(cocoindex.op.FunctionSpec): 
"""Convert a PDF to markdown.""" + @cocoindex.op.executor_class(gpu=True, cache=True, behavior_version=1) class PdfToMarkdownExecutor: """Executor for PdfToMarkdown.""" @@ -20,7 +22,9 @@ class PdfToMarkdownExecutor: def prepare(self): config_parser = ConfigParser({}) - self._converter = PdfConverter(create_model_dict(), config=config_parser.generate_config_dict()) + self._converter = PdfConverter( + create_model_dict(), config=config_parser.generate_config_dict() + ) def __call__(self, content: bytes) -> str: with tempfile.NamedTemporaryFile(delete=True, suffix=".pdf") as temp_file: @@ -29,40 +33,51 @@ def __call__(self, content: bytes) -> str: text, _, _ = text_from_rendered(self._converter(temp_file.name)) return text + @dataclasses.dataclass class ArgInfo: """Information about an argument of a method.""" + name: str description: str + @dataclasses.dataclass class MethodInfo: """Information about a method.""" + name: str args: list[ArgInfo] description: str + @dataclasses.dataclass class ClassInfo: """Information about a class.""" + name: str description: str methods: list[MethodInfo] + @dataclasses.dataclass class ModuleInfo: """Information about a Python module.""" + title: str description: str classes: list[ClassInfo] methods: list[MethodInfo] + @dataclasses.dataclass class ModuleSummary: """Summary info about a Python module.""" + num_classes: int num_methods: int + @cocoindex.op.function() def summarize_module(module_info: ModuleInfo) -> ModuleSummary: """Summarize a Python module.""" @@ -71,12 +86,17 @@ def summarize_module(module_info: ModuleInfo) -> ModuleSummary: num_methods=len(module_info.methods), ) + @cocoindex.flow_def(name="ManualExtraction") -def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def manual_extraction_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that extracts manual information from a Markdown. 
""" - data_scope["documents"] = flow_builder.add_source(cocoindex.sources.LocalFile(path="manuals", binary=True)) + data_scope["documents"] = flow_builder.add_source( + cocoindex.sources.LocalFile(path="manuals", binary=True) + ) modules_index = data_scope.add_collector() @@ -85,24 +105,23 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco doc["module_info"] = doc["markdown"].transform( cocoindex.functions.ExtractByLlm( llm_spec=cocoindex.LlmSpec( - api_type=cocoindex.LlmApiType.OLLAMA, - # See the full list of models: https://ollama.com/library - model="llama3.2" + api_type=cocoindex.LlmApiType.OLLAMA, + # See the full list of models: https://ollama.com/library + model="llama3.2", ), - # Replace by this spec below, to use OpenAI API model instead of ollama # llm_spec=cocoindex.LlmSpec( # api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4o"), - # Replace by this spec below, to use Gemini API model # llm_spec=cocoindex.LlmSpec( # api_type=cocoindex.LlmApiType.GEMINI, model="gemini-2.0-flash"), - # Replace by this spec below, to use Anthropic API model # llm_spec=cocoindex.LlmSpec( # api_type=cocoindex.LlmApiType.ANTHROPIC, model="claude-3-5-sonnet-latest"), output_type=ModuleInfo, - instruction="Please extract Python module information from the manual.")) + instruction="Please extract Python module information from the manual.", + ) + ) doc["module_summary"] = doc["module_info"].transform(summarize_module) modules_index.collect( filename=doc["filename"], @@ -114,4 +133,4 @@ def manual_extraction_flow(flow_builder: cocoindex.FlowBuilder, data_scope: coco "modules", cocoindex.storages.Postgres(table_name="modules_info"), primary_key_fields=["filename"], - ) \ No newline at end of file + ) diff --git a/examples/pdf_embedding/main.py b/examples/pdf_embedding/main.py index 83088d40d..b637e884c 100644 --- a/examples/pdf_embedding/main.py +++ b/examples/pdf_embedding/main.py @@ -25,7 +25,9 @@ class PdfToMarkdownExecutor: def prepare(self): config_parser = ConfigParser({}) - self._converter = PdfConverter(create_model_dict(), config=config_parser.generate_config_dict()) + self._converter = PdfConverter( + create_model_dict(), config=config_parser.generate_config_dict() + ) def __call__(self, content: bytes) -> str: with tempfile.NamedTemporaryFile(delete=True, suffix=".pdf") as temp_file: @@ -36,22 +38,30 @@ def __call__(self, content: bytes) -> str: @cocoindex.transform_flow() -def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]: +def text_to_embedding( + text: cocoindex.DataSlice[str], +) -> cocoindex.DataSlice[list[float]]: """ Embed the text using a SentenceTransformer model. This is a shared logic between indexing and querying, so extract it as a function. """ return text.transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) @cocoindex.flow_def(name="PdfEmbedding") -def pdf_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def pdf_embedding_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that embeds files into a vector database. 
""" - data_scope["documents"] = flow_builder.add_source(cocoindex.sources.LocalFile(path="pdf_files", binary=True)) + data_scope["documents"] = flow_builder.add_source( + cocoindex.sources.LocalFile(path="pdf_files", binary=True) + ) pdf_embeddings = data_scope.add_collector() @@ -59,13 +69,20 @@ def pdf_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde doc["markdown"] = doc["content"].transform(PdfToMarkdown()) doc["chunks"] = doc["markdown"].transform( cocoindex.functions.SplitRecursively(), - language="markdown", chunk_size=2000, chunk_overlap=500) + language="markdown", + chunk_size=2000, + chunk_overlap=500, + ) with doc["chunks"].row() as chunk: chunk["embedding"] = text_to_embedding(chunk["text"]) - pdf_embeddings.collect(id=cocoindex.GeneratedField.UUID, - filename=doc["filename"], location=chunk["location"], - text=chunk["text"], embedding=chunk["embedding"]) + pdf_embeddings.collect( + id=cocoindex.GeneratedField.UUID, + filename=doc["filename"], + location=chunk["location"], + text=chunk["text"], + embedding=chunk["embedding"], + ) pdf_embeddings.export( "pdf_embeddings", @@ -74,21 +91,29 @@ def pdf_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde vector_indexes=[ cocoindex.VectorIndexDef( field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)]) + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ) + ], + ) def search(pool: ConnectionPool, query: str, top_k: int = 5): # Get the table name, for the export target in the pdf_embedding_flow above. - table_name = cocoindex.utils.get_target_storage_default_name(pdf_embedding_flow, "pdf_embeddings") + table_name = cocoindex.utils.get_target_storage_default_name( + pdf_embedding_flow, "pdf_embeddings" + ) # Evaluate the transform flow defined above with the input query, to get the embedding. query_vector = text_to_embedding.eval(query) # Run the query and get the results. with pool.connection() as conn: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT filename, text, embedding <=> %s::vector AS distance FROM {table_name} ORDER BY distance LIMIT %s - """, (query_vector, top_k)) + """, + (query_vector, top_k), + ) return [ {"filename": row[0], "text": row[1], "score": 1.0 - row[2]} for row in cur.fetchall() @@ -112,7 +137,7 @@ def _main(): # Run queries in a loop to demonstrate the query capabilities. while True: query = input("Enter search query (or Enter to quit): ") - if query == '': + if query == "": break # Run the query function with the database connection pool and the query. results = search(pool, query) diff --git a/examples/product_recommendation/main.py b/examples/product_recommendation/main.py index 63af3bb77..7580e33d3 100644 --- a/examples/product_recommendation/main.py +++ b/examples/product_recommendation/main.py @@ -1,6 +1,7 @@ """ This example shows how to extract relationships from Markdown documents and build a knowledge graph. """ + import dataclasses import datetime import cocoindex @@ -25,6 +26,7 @@ """ + @dataclasses.dataclass class ProductInfo: id: str @@ -32,11 +34,12 @@ class ProductInfo: price: float detail: str + @dataclasses.dataclass class ProductTaxonomy: """ Taxonomy for the product. - + A taxonomy is a concise noun (or short noun phrase), based on its core functionality, without specific details such as branding, style, etc. Always use the most common words in US English. @@ -45,8 +48,10 @@ class ProductTaxonomy: A product may have multiple taxonomies. 
Avoid large categories like "office supplies" or "electronics". Use specific ones, like "pen" or "printer". """ + name: str + @dataclasses.dataclass class ProductTaxonomyInfo: """ @@ -56,9 +61,11 @@ class ProductTaxonomyInfo: - taxonomies: Taxonomies for the current product. - complementary_taxonomies: Think about when customers buy this product, what else they might need as complementary products. Put labels for these complentary products. """ + taxonomies: list[ProductTaxonomy] complementary_taxonomies: list[ProductTaxonomy] + @cocoindex.op.function(behavior_version=2) def extract_product_info(product: cocoindex.Json, filename: str) -> ProductInfo: # Print markdown for LLM to extract the taxonomy and complimentary taxonomy @@ -71,35 +78,49 @@ def extract_product_info(product: cocoindex.Json, filename: str) -> ProductInfo: @cocoindex.flow_def(name="StoreProduct") -def store_product_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def store_product_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that extracts triples from files and build knowledge graph. """ data_scope["products"] = flow_builder.add_source( - cocoindex.sources.LocalFile(path="products", - included_patterns=["*.json"]), - refresh_interval=datetime.timedelta(seconds=5)) + cocoindex.sources.LocalFile(path="products", included_patterns=["*.json"]), + refresh_interval=datetime.timedelta(seconds=5), + ) product_node = data_scope.add_collector() product_taxonomy = data_scope.add_collector() product_complementary_taxonomy = data_scope.add_collector() - with data_scope["products"].row() as product: - data = (product["content"].transform(cocoindex.functions.ParseJson(), language="json") - .transform(extract_product_info, filename=product["filename"])) - taxonomy = data["detail"].transform(cocoindex.functions.ExtractByLlm( - llm_spec=cocoindex.LlmSpec( - api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4.1"), - output_type=ProductTaxonomyInfo)) - - product_node.collect(id=data["id"], title=data["title"], price=data["price"]) - with taxonomy['taxonomies'].row() as t: - product_taxonomy.collect(id=cocoindex.GeneratedField.UUID, product_id=data["id"], taxonomy=t["name"]) - with taxonomy['complementary_taxonomies'].row() as t: - product_complementary_taxonomy.collect(id=cocoindex.GeneratedField.UUID, product_id=data["id"], taxonomy=t["name"]) - + data = ( + product["content"] + .transform(cocoindex.functions.ParseJson(), language="json") + .transform(extract_product_info, filename=product["filename"]) + ) + taxonomy = data["detail"].transform( + cocoindex.functions.ExtractByLlm( + llm_spec=cocoindex.LlmSpec( + api_type=cocoindex.LlmApiType.OPENAI, model="gpt-4.1" + ), + output_type=ProductTaxonomyInfo, + ) + ) + product_node.collect(id=data["id"], title=data["title"], price=data["price"]) + with taxonomy["taxonomies"].row() as t: + product_taxonomy.collect( + id=cocoindex.GeneratedField.UUID, + product_id=data["id"], + taxonomy=t["name"], + ) + with taxonomy["complementary_taxonomies"].row() as t: + product_complementary_taxonomy.collect( + id=cocoindex.GeneratedField.UUID, + product_id=data["id"], + taxonomy=t["name"], + ) conn_spec = cocoindex.add_auth_entry( "Neo4jConnection", @@ -107,13 +128,13 @@ def store_product_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde uri="bolt://localhost:7687", user="neo4j", password="cocoindex", - )) + ), + ) product_node.export( "product_node", cocoindex.storages.Neo4j( - connection=conn_spec, - 
mapping=cocoindex.storages.Nodes(label="Product") + connection=conn_spec, mapping=cocoindex.storages.Nodes(label="Product") ), primary_key_fields=["id"], ) @@ -136,15 +157,17 @@ def store_product_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde label="Product", fields=[ cocoindex.storages.TargetFieldMapping( - source="product_id", target="id"), - ] + source="product_id", target="id" + ), + ], ), target=cocoindex.storages.NodeFromFields( label="Taxonomy", fields=[ cocoindex.storages.TargetFieldMapping( - source="taxonomy", target="value"), - ] + source="taxonomy", target="value" + ), + ], ), ), ), @@ -160,17 +183,19 @@ def store_product_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoinde label="Product", fields=[ cocoindex.storages.TargetFieldMapping( - source="product_id", target="id"), - ] + source="product_id", target="id" + ), + ], ), target=cocoindex.storages.NodeFromFields( label="Taxonomy", fields=[ cocoindex.storages.TargetFieldMapping( - source="taxonomy", target="value"), - ] + source="taxonomy", target="value" + ), + ], ), ), ), primary_key_fields=["id"], - ) \ No newline at end of file + ) diff --git a/examples/text_embedding/Text_Embedding.ipynb b/examples/text_embedding/Text_Embedding.ipynb index afd015569..5498f04c5 100644 --- a/examples/text_embedding/Text_Embedding.ipynb +++ b/examples/text_embedding/Text_Embedding.ipynb @@ -1,409 +1,408 @@ { - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "Up70lME5E0Tc" - }, - "source": [ - "# ![icon.svg](https://cocoindex.io/icon.svg) Welcome to [Cocoindex](https://cocoindex.io/)\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "bJ3LGSyF9D1M" - }, - "source": [ - "\n", - "# ![icon.svg](https://cocoindex.io/icon.svg) This example will show you how you can get started with Cocoindex by building embedding for RAG" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ymNZ0fk09noG" - }, - "source": [ - "# Install Cocoindex and other required packages using pip" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "s4MT3saT9COe" - }, - "source": [] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true, - "id": "rQcJanCi-W3I" - }, - "outputs": [], - "source": [ - "%pip install cocoindex python-dotenv psycopg[binary,pool]" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xh2sMemiA7_N" - }, - "source": [ - "# Grab some markdown files for demo" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true, - "id": "0Gi-MHrNA8sQ" - }, - "outputs": [], - "source": [ - "!mkdir -p markdown_files && \\\n", - "wget -P markdown_files https://raw.githubusercontent.com/cocoindex-io/cocoindex/refs/heads/main/examples/text_embedding/markdown_files/1706.03762v7.md && \\\n", - "wget -P markdown_files https://raw.githubusercontent.com/cocoindex-io/cocoindex/refs/heads/main/examples/text_embedding/markdown_files/1810.04805v2.md && \\\n", - "wget -P markdown_files https://raw.githubusercontent.com/cocoindex-io/cocoindex/refs/heads/main/examples/text_embedding/markdown_files/rfc8259.md\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "hPctYqRAzgEq" - }, - "outputs": [], - "source": [] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZEetEtmPAuZ-" - }, - "source": [ - "# Create a Postgres Server" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "collapsed": true, - "id": "lkATpKLUAuuY" - }, - 
"outputs": [], - "source": [ - "# Update package lists\n", - "!sudo apt-get update\n", - "\n", - "# Install PostgreSQL setup helper\n", - "!sudo apt install -y postgresql-common\n", - "\n", - "# Automatically press Enter for the setup script\n", - "!yes \"\" | sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh\n", - "\n", - "# Install PostgreSQL 17 and pgvector extension\n", - "!sudo apt install -y postgresql-17 postgresql-17-pgvector\n", - "\n", - "# Start PostgreSQL service\n", - "!sudo service postgresql start\n", - "\n", - "# Create user and database for cocoindex\n", - "!sudo -u postgres psql -c \"CREATE USER cocoindex WITH PASSWORD 'cocoindex';\"\n", - "!sudo -u postgres createdb cocoindex -O cocoindex\n", - "\n", - "# Enable the pgvector extension\n", - "!sudo -u postgres psql -d cocoindex -c \"CREATE EXTENSION IF NOT EXISTS vector;\"\n", - "\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "utZpExYkAzi6" - }, - "source": [ - "# Update .env with POSTGRES URL" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "X3P8pEUOA5D2" - }, - "outputs": [], - "source": [ - "%%writefile .env\n", - "COCOINDEX_DATABASE_URL=\"postgresql://cocoindex:cocoindex@localhost:5432/cocoindex\"" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "9zN612eW_1nX" - }, - "source": [ - "# Create a new file and import modules" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "7HUYtsoN-10D" - }, - "outputs": [], - "source": [ - "%%writefile main.py\n", - "from dotenv import load_dotenv\n", - "import os\n", - "from psycopg_pool import ConnectionPool\n", - "import cocoindex\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "2DOY5Q27ADS2" - }, - "source": [ - "# Define your embedding function" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "L_puYY6FABbr" - }, - "outputs": [], - "source": [ - "%%writefile -a main.py\n", - "\n", - "@cocoindex.transform_flow()\n", - "def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:\n", - " \"\"\"\n", - " Embed the text using a SentenceTransformer model.\n", - " This is shared logic between indexing and querying.\n", - " \"\"\"\n", - " return text.transform(\n", - " cocoindex.functions.SentenceTransformerEmbed(\n", - " model=\"sentence-transformers/all-MiniLM-L6-v2\"))\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "H6j2aiRaAEKz" - }, - "source": [ - "# Define your flow" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oatJUXjAAEhE" - }, - "outputs": [], - "source": [ - "%%writefile -a main.py\n", - "\n", - "@cocoindex.flow_def(name=\"TextEmbedding\")\n", - "def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):\n", - " \"\"\"\n", - " Define a flow that embeds text into a vector database.\n", - " \"\"\"\n", - " data_scope[\"documents\"] = flow_builder.add_source(\n", - " cocoindex.sources.LocalFile(path=\"markdown_files\"))\n", - "\n", - " doc_embeddings = data_scope.add_collector()\n", - "\n", - " with data_scope[\"documents\"].row() as doc:\n", - " doc[\"chunks\"] = doc[\"content\"].transform(\n", - " cocoindex.functions.SplitRecursively(),\n", - " language=\"markdown\", chunk_size=2000, chunk_overlap=500)\n", - "\n", - " with doc[\"chunks\"].row() as chunk:\n", - " chunk[\"embedding\"] = text_to_embedding(chunk[\"text\"])\n", - " 
doc_embeddings.collect(filename=doc[\"filename\"], location=chunk[\"location\"],\n", - " text=chunk[\"text\"], embedding=chunk[\"embedding\"])\n", - "\n", - " doc_embeddings.export(\n", - " \"doc_embeddings\",\n", - " cocoindex.storages.Postgres(),\n", - " primary_key_fields=[\"filename\", \"location\"],\n", - " vector_indexes=[\n", - " cocoindex.VectorIndexDef(\n", - " field_name=\"embedding\",\n", - " metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)])\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "KLb41N5UAFJx" - }, - "source": [ - "\n", - "# Provide query logic\n", - "\n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tRdfIP6OAFe1" - }, - "outputs": [], - "source": [ - "%%writefile -a main.py\n", - "\n", - "def search(pool: ConnectionPool, query: str, top_k: int = 5):\n", - " # Get the table name, for the export target in the text_embedding_flow above.\n", - " table_name = cocoindex.utils.get_target_storage_default_name(text_embedding_flow, \"doc_embeddings\")\n", - " # Evaluate the transform flow defined above with the input query, to get the embedding.\n", - " query_vector = text_to_embedding.eval(query)\n", - " # Run the query and get the results.\n", - " with pool.connection() as conn:\n", - " with conn.cursor() as cur:\n", - " cur.execute(f\"\"\"\n", - " SELECT filename, text, embedding <=> %s::vector AS distance\n", - " FROM {table_name} ORDER BY distance LIMIT %s\n", - " \"\"\", (query_vector, top_k))\n", - " return [\n", - " {\"filename\": row[0], \"text\": row[1], \"score\": 1.0 - row[2]}\n", - " for row in cur.fetchall()\n", - " ]\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "IUBdoOmOAgwc" - }, - "source": [ - "# Define search function and main" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "W78hBbDiAhFh" - }, - "outputs": [], - "source": [ - "%%writefile -a main.py\n", - "\n", - "def _main():\n", - " # Initialize the database connection pool.\n", - " pool = ConnectionPool(os.getenv(\"COCOINDEX_DATABASE_URL\"))\n", - " # Run queries in a loop to demonstrate the query capabilities.\n", - " while True:\n", - " try:\n", - " query = input(\"Enter search query (or Enter to quit): \")\n", - " if query == '':\n", - " break\n", - " # Run the query function with the database connection pool and the query.\n", - " results = search(pool, query)\n", - " print(\"\\nSearch results:\")\n", - " for result in results:\n", - " print(f\"[{result['score']:.3f}] {result['filename']}\")\n", - " print(f\" {result['text']}\")\n", - " print(\"---\")\n", - " print()\n", - " except KeyboardInterrupt:\n", - " break\n", - "\n", - "if __name__ == \"__main__\":\n", - " load_dotenv(override=True)\n", - " cocoindex.init()\n", - " _main()\n" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "I2oI_pjxCkRa" - }, - "source": [ - "# Setup" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "oBStjaI0Cli_" - }, - "outputs": [], - "source": [ - "!yes yes | cocoindex setup main.py" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "aPBDVrG_CmwH" - }, - "source": [ - "# Update" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "M9g6xIZHCn5T" - }, - "outputs": [], - "source": [ - "!cocoindex update main.py" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "nIM78MBRCppz" - }, - "source": [ - "# Run query" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": { - "id": "6E-HR_KSCqzP" - }, - "outputs": [], - "source": [ - "!python main.py" - ] - } - ], - "metadata": { - "colab": { - "provenance": [] - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "Up70lME5E0Tc" + }, + "source": [ + "# ![icon.svg](https://cocoindex.io/icon.svg) Welcome to [Cocoindex](https://cocoindex.io/)\n", + "\n" + ] }, - "nbformat": 4, - "nbformat_minor": 0 + { + "cell_type": "markdown", + "metadata": { + "id": "bJ3LGSyF9D1M" + }, + "source": [ + "\n", + "# ![icon.svg](https://cocoindex.io/icon.svg) This example will show you how you can get started with Cocoindex by building embedding for RAG" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ymNZ0fk09noG" + }, + "source": [ + "# Install Cocoindex and other required packages using pip" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "s4MT3saT9COe" + }, + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "rQcJanCi-W3I" + }, + "outputs": [], + "source": [ + "%pip install cocoindex python-dotenv psycopg[binary,pool]" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xh2sMemiA7_N" + }, + "source": [ + "# Grab some markdown files for demo" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "0Gi-MHrNA8sQ" + }, + "outputs": [], + "source": [ + "!mkdir -p markdown_files && \\\n", + "wget -P markdown_files https://raw.githubusercontent.com/cocoindex-io/cocoindex/refs/heads/main/examples/text_embedding/markdown_files/1706.03762v7.md && \\\n", + "wget -P markdown_files https://raw.githubusercontent.com/cocoindex-io/cocoindex/refs/heads/main/examples/text_embedding/markdown_files/1810.04805v2.md && \\\n", + "wget -P markdown_files https://raw.githubusercontent.com/cocoindex-io/cocoindex/refs/heads/main/examples/text_embedding/markdown_files/rfc8259.md" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "hPctYqRAzgEq" + }, + "outputs": [], + "source": [] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZEetEtmPAuZ-" + }, + "source": [ + "# Create a Postgres Server" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true, + "id": "lkATpKLUAuuY" + }, + "outputs": [], + "source": [ + "# Update package lists\n", + "!sudo apt-get update\n", + "\n", + "# Install PostgreSQL setup helper\n", + "!sudo apt install -y postgresql-common\n", + "\n", + "# Automatically press Enter for the setup script\n", + "!yes \"\" | sudo /usr/share/postgresql-common/pgdg/apt.postgresql.org.sh\n", + "\n", + "# Install PostgreSQL 17 and pgvector extension\n", + "!sudo apt install -y postgresql-17 postgresql-17-pgvector\n", + "\n", + "# Start PostgreSQL service\n", + "!sudo service postgresql start\n", + "\n", + "# Create user and database for cocoindex\n", + "!sudo -u postgres psql -c \"CREATE USER cocoindex WITH PASSWORD 'cocoindex';\"\n", + "!sudo -u postgres createdb cocoindex -O cocoindex\n", + "\n", + "# Enable the pgvector extension\n", + "!sudo -u postgres psql -d cocoindex -c \"CREATE EXTENSION IF NOT EXISTS vector;\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "utZpExYkAzi6" + }, + "source": [ + "# Update .env with POSTGRES URL" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "X3P8pEUOA5D2" 
+ }, + "outputs": [], + "source": [ + "%%writefile .env\n", + "COCOINDEX_DATABASE_URL=\"postgresql://cocoindex:cocoindex@localhost:5432/cocoindex\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9zN612eW_1nX" + }, + "source": [ + "# Create a new file and import modules" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "7HUYtsoN-10D" + }, + "outputs": [], + "source": [ + "%%writefile main.py\n", + "from dotenv import load_dotenv\n", + "import os\n", + "from psycopg_pool import ConnectionPool\n", + "import cocoindex\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "2DOY5Q27ADS2" + }, + "source": [ + "# Define your embedding function" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "L_puYY6FABbr" + }, + "outputs": [], + "source": [ + "%%writefile -a main.py\n", + "\n", + "@cocoindex.transform_flow()\n", + "def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]:\n", + " \"\"\"\n", + " Embed the text using a SentenceTransformer model.\n", + " This is shared logic between indexing and querying.\n", + " \"\"\"\n", + " return text.transform(\n", + " cocoindex.functions.SentenceTransformerEmbed(\n", + " model=\"sentence-transformers/all-MiniLM-L6-v2\"))\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "H6j2aiRaAEKz" + }, + "source": [ + "# Define your flow" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oatJUXjAAEhE" + }, + "outputs": [], + "source": [ + "%%writefile -a main.py\n", + "\n", + "@cocoindex.flow_def(name=\"TextEmbedding\")\n", + "def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope):\n", + " \"\"\"\n", + " Define a flow that embeds text into a vector database.\n", + " \"\"\"\n", + " data_scope[\"documents\"] = flow_builder.add_source(\n", + " cocoindex.sources.LocalFile(path=\"markdown_files\"))\n", + "\n", + " doc_embeddings = data_scope.add_collector()\n", + "\n", + " with data_scope[\"documents\"].row() as doc:\n", + " doc[\"chunks\"] = doc[\"content\"].transform(\n", + " cocoindex.functions.SplitRecursively(),\n", + " language=\"markdown\", chunk_size=2000, chunk_overlap=500)\n", + "\n", + " with doc[\"chunks\"].row() as chunk:\n", + " chunk[\"embedding\"] = text_to_embedding(chunk[\"text\"])\n", + " doc_embeddings.collect(filename=doc[\"filename\"], location=chunk[\"location\"],\n", + " text=chunk[\"text\"], embedding=chunk[\"embedding\"])\n", + "\n", + " doc_embeddings.export(\n", + " \"doc_embeddings\",\n", + " cocoindex.storages.Postgres(),\n", + " primary_key_fields=[\"filename\", \"location\"],\n", + " vector_indexes=[\n", + " cocoindex.VectorIndexDef(\n", + " field_name=\"embedding\",\n", + " metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)])\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "KLb41N5UAFJx" + }, + "source": [ + "\n", + "# Provide query logic\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tRdfIP6OAFe1" + }, + "outputs": [], + "source": [ + "%%writefile -a main.py\n", + "\n", + "def search(pool: ConnectionPool, query: str, top_k: int = 5):\n", + " # Get the table name, for the export target in the text_embedding_flow above.\n", + " table_name = cocoindex.utils.get_target_storage_default_name(text_embedding_flow, \"doc_embeddings\")\n", + " # Evaluate the transform flow defined above with the input query, to get the embedding.\n", + " 
query_vector = text_to_embedding.eval(query)\n", + " # Run the query and get the results.\n", + " with pool.connection() as conn:\n", + " with conn.cursor() as cur:\n", + " cur.execute(f\"\"\"\n", + " SELECT filename, text, embedding <=> %s::vector AS distance\n", + " FROM {table_name} ORDER BY distance LIMIT %s\n", + " \"\"\", (query_vector, top_k))\n", + " return [\n", + " {\"filename\": row[0], \"text\": row[1], \"score\": 1.0 - row[2]}\n", + " for row in cur.fetchall()\n", + " ]\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "IUBdoOmOAgwc" + }, + "source": [ + "# Define search function and main" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "W78hBbDiAhFh" + }, + "outputs": [], + "source": [ + "%%writefile -a main.py\n", + "\n", + "def _main():\n", + " # Initialize the database connection pool.\n", + " pool = ConnectionPool(os.getenv(\"COCOINDEX_DATABASE_URL\"))\n", + " # Run queries in a loop to demonstrate the query capabilities.\n", + " while True:\n", + " try:\n", + " query = input(\"Enter search query (or Enter to quit): \")\n", + " if query == '':\n", + " break\n", + " # Run the query function with the database connection pool and the query.\n", + " results = search(pool, query)\n", + " print(\"\\nSearch results:\")\n", + " for result in results:\n", + " print(f\"[{result['score']:.3f}] {result['filename']}\")\n", + " print(f\" {result['text']}\")\n", + " print(\"---\")\n", + " print()\n", + " except KeyboardInterrupt:\n", + " break\n", + "\n", + "if __name__ == \"__main__\":\n", + " load_dotenv(override=True)\n", + " cocoindex.init()\n", + " _main()\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "I2oI_pjxCkRa" + }, + "source": [ + "# Setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oBStjaI0Cli_" + }, + "outputs": [], + "source": [ + "!yes yes | cocoindex setup main.py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "aPBDVrG_CmwH" + }, + "source": [ + "# Update" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "M9g6xIZHCn5T" + }, + "outputs": [], + "source": [ + "!cocoindex update main.py" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nIM78MBRCppz" + }, + "source": [ + "# Run query" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6E-HR_KSCqzP" + }, + "outputs": [], + "source": [ + "!python main.py" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 } diff --git a/examples/text_embedding/main.py b/examples/text_embedding/main.py index 5bd5e008b..e2684d772 100644 --- a/examples/text_embedding/main.py +++ b/examples/text_embedding/main.py @@ -3,35 +3,51 @@ import cocoindex import os + @cocoindex.transform_flow() -def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]: +def text_to_embedding( + text: cocoindex.DataSlice[str], +) -> cocoindex.DataSlice[list[float]]: """ Embed the text using a SentenceTransformer model. This is a shared logic between indexing and querying, so extract it as a function. 
""" return text.transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) + @cocoindex.flow_def(name="TextEmbedding") -def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope): +def text_embedding_flow( + flow_builder: cocoindex.FlowBuilder, data_scope: cocoindex.DataScope +): """ Define an example flow that embeds text into a vector database. """ data_scope["documents"] = flow_builder.add_source( - cocoindex.sources.LocalFile(path="markdown_files")) + cocoindex.sources.LocalFile(path="markdown_files") + ) doc_embeddings = data_scope.add_collector() with data_scope["documents"].row() as doc: doc["chunks"] = doc["content"].transform( cocoindex.functions.SplitRecursively(), - language="markdown", chunk_size=2000, chunk_overlap=500) + language="markdown", + chunk_size=2000, + chunk_overlap=500, + ) with doc["chunks"].row() as chunk: chunk["embedding"] = text_to_embedding(chunk["text"]) - doc_embeddings.collect(filename=doc["filename"], location=chunk["location"], - text=chunk["text"], embedding=chunk["embedding"]) + doc_embeddings.collect( + filename=doc["filename"], + location=chunk["location"], + text=chunk["text"], + embedding=chunk["embedding"], + ) doc_embeddings.export( "doc_embeddings", @@ -40,33 +56,42 @@ def text_embedding_flow(flow_builder: cocoindex.FlowBuilder, data_scope: cocoind vector_indexes=[ cocoindex.VectorIndexDef( field_name="embedding", - metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY)]) + metric=cocoindex.VectorSimilarityMetric.COSINE_SIMILARITY, + ) + ], + ) def search(pool: ConnectionPool, query: str, top_k: int = 5): # Get the table name, for the export target in the text_embedding_flow above. - table_name = cocoindex.utils.get_target_storage_default_name(text_embedding_flow, "doc_embeddings") + table_name = cocoindex.utils.get_target_storage_default_name( + text_embedding_flow, "doc_embeddings" + ) # Evaluate the transform flow defined above with the input query, to get the embedding. query_vector = text_to_embedding.eval(query) # Run the query and get the results. with pool.connection() as conn: with conn.cursor() as cur: - cur.execute(f""" + cur.execute( + f""" SELECT filename, text, embedding <=> %s::vector AS distance FROM {table_name} ORDER BY distance LIMIT %s - """, (query_vector, top_k)) + """, + (query_vector, top_k), + ) return [ {"filename": row[0], "text": row[1], "score": 1.0 - row[2]} for row in cur.fetchall() ] + def _main(): # Initialize the database connection pool. pool = ConnectionPool(os.getenv("COCOINDEX_DATABASE_URL")) # Run queries in a loop to demonstrate the query capabilities. while True: query = input("Enter search query (or Enter to quit): ") - if query == '': + if query == "": break # Run the query function with the database connection pool and the query. 
results = search(pool, query) @@ -77,6 +102,7 @@ def _main(): print("---") print() + if __name__ == "__main__": load_dotenv() cocoindex.init() diff --git a/examples/text_embedding_qdrant/main.py b/examples/text_embedding_qdrant/main.py index 53a61bedf..f8eabb72d 100644 --- a/examples/text_embedding_qdrant/main.py +++ b/examples/text_embedding_qdrant/main.py @@ -9,14 +9,18 @@ @cocoindex.transform_flow() -def text_to_embedding(text: cocoindex.DataSlice[str]) -> cocoindex.DataSlice[list[float]]: +def text_to_embedding( + text: cocoindex.DataSlice[str], +) -> cocoindex.DataSlice[list[float]]: """ Embed the text using a SentenceTransformer model. This is a shared logic between indexing and querying, so extract it as a function. """ return text.transform( cocoindex.functions.SentenceTransformerEmbed( - model="sentence-transformers/all-MiniLM-L6-v2")) + model="sentence-transformers/all-MiniLM-L6-v2" + ) + ) @cocoindex.flow_def(name="TextEmbeddingWithQdrant") @@ -60,23 +64,24 @@ def text_embedding_flow( setup_by_user=True, ) + def _main(): # Initialize Qdrant client client = QdrantClient(url=QDRANT_GRPC_URL, prefer_grpc=True) - + # Run queries in a loop to demonstrate the query capabilities. while True: query = input("Enter search query (or Enter to quit): ") if query == "": break - + # Get the embedding for the query query_embedding = text_to_embedding.eval(query) - + search_results = client.search( collection_name=QDRANT_COLLECTION, query_vector=("text_embedding", query_embedding), - limit=10 + limit=10, ) print("\nSearch results:") for result in search_results: diff --git a/pyproject.toml b/pyproject.toml index c7e5977fd..b7fb7eb3c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ features = ["pyo3/extension-module"] [project.optional-dependencies] test = ["pytest"] +dev = ["ruff"] [tool.mypy] python_version = "3.11" diff --git a/python/cocoindex/__init__.py b/python/cocoindex/__init__.py index e7c1fd1df..2b252299a 100644 --- a/python/cocoindex/__init__.py +++ b/python/cocoindex/__init__.py @@ -1,6 +1,7 @@ """ Cocoindex is a framework for building and running indexing pipelines. """ + from . import functions, query, sources, storages, cli, utils from .auth_registry import AuthEntryReference, add_auth_entry, ref_auth_entry @@ -13,4 +14,4 @@ from .index import VectorSimilarityMetric, VectorIndexDef, IndexOptions from .setting import DatabaseConnectionSpec, Settings, ServerSettings from .setting import get_app_namespace -from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json \ No newline at end of file +from .typing import Float32, Float64, LocalDateTime, OffsetDateTime, Range, Vector, Json diff --git a/python/cocoindex/auth_registry.py b/python/cocoindex/auth_registry.py index 9b1a20738..7a91a71ba 100644 --- a/python/cocoindex/auth_registry.py +++ b/python/cocoindex/auth_registry.py @@ -10,16 +10,20 @@ T = TypeVar("T") + @dataclass class AuthEntryReference(Generic[T]): """Reference an auth entry by its key.""" + key: str + def add_auth_entry(key: str, value: T) -> AuthEntryReference[T]: """Add an auth entry to the registry. 
Returns its reference.""" _engine.add_auth_entry(key, dump_engine_object(value)) return AuthEntryReference(key) + def ref_auth_entry(key: str) -> AuthEntryReference: """Reference an auth entry by its key.""" - return AuthEntryReference(key) \ No newline at end of file + return AuthEntryReference(key) diff --git a/python/cocoindex/cli.py b/python/cocoindex/cli.py index 01ce9f09b..1a861266f 100644 --- a/python/cocoindex/cli.py +++ b/python/cocoindex/cli.py @@ -14,7 +14,8 @@ from .setup import sync_setup, drop_setup, flow_names_with_setup, apply_setup_changes # Create ServerSettings lazily upon first call, as environment variables may be loaded from files, etc. -COCOINDEX_HOST = 'https://cocoindex.io' +COCOINDEX_HOST = "https://cocoindex.io" + def _parse_app_flow_specifier(specifier: str) -> tuple[str, str | None]: """Parses 'module_or_path[:flow_name]' into (module_or_path, flow_name | None).""" @@ -25,7 +26,7 @@ def _parse_app_flow_specifier(specifier: str) -> tuple[str, str | None]: raise click.BadParameter( f"Application module/path part is missing or invalid in specifier: '{specifier}'. " "Expected format like 'myapp.py' or 'myapp:MyFlow'.", - param_hint="APP_SPECIFIER" + param_hint="APP_SPECIFIER", ) if len(parts) == 1: @@ -33,7 +34,7 @@ def _parse_app_flow_specifier(specifier: str) -> tuple[str, str | None]: flow_ref_part = parts[1] - if not flow_ref_part: # Handles empty string after colon + if not flow_ref_part: # Handles empty string after colon return app_ref, None if not flow_ref_part.isidentifier(): @@ -41,10 +42,11 @@ def _parse_app_flow_specifier(specifier: str) -> tuple[str, str | None]: f"Invalid format for flow name part ('{flow_ref_part}') in specifier '{specifier}'. " "If a colon separates the application from the flow name, the flow name should typically be " "a valid identifier (e.g., alphanumeric with underscores, not starting with a number).", - param_hint="APP_SPECIFIER" + param_hint="APP_SPECIFIER", ) return app_ref, flow_ref_part + def _get_app_ref_from_specifier( specifier: str, ) -> str: @@ -59,12 +61,13 @@ def _get_app_ref_from_specifier( click.style( f"Ignoring flow name '{flow_ref}' in '{specifier}': " f"this command operates on the entire app/module '{app_ref}'.", - fg='yellow' + fg="yellow", ), - err=True + err=True, ) return app_ref + def _load_user_app(app_target: str) -> types.ModuleType: """ Loads the user's application, which can be a file path or an installed module name. 
@@ -81,7 +84,7 @@ def _load_user_app(app_target: str) -> types.ModuleType: app_path = os.path.abspath(app_target) app_dir = os.path.dirname(app_path) module_name = os.path.splitext(os.path.basename(app_path))[0] - + if app_dir not in sys.path: sys.path.insert(0, app_dir) try: @@ -89,7 +92,7 @@ def _load_user_app(app_target: str) -> types.ModuleType: if spec is None: raise ImportError(f"Could not create spec for file: {app_path}") module = importlib.util.module_from_spec(spec) - sys.modules[spec.name] = module + sys.modules[spec.name] = module if spec.loader is None: raise ImportError(f"Could not create loader for file: {app_path}") spec.loader.exec_module(module) @@ -106,17 +109,22 @@ def _load_user_app(app_target: str) -> types.ModuleType: except ImportError as e: raise click.ClickException(f"Failed to load module '{app_target}': {e}") except Exception as e: - raise click.ClickException(f"Unexpected error importing module '{app_target}': {e}") + raise click.ClickException( + f"Unexpected error importing module '{app_target}': {e}" + ) + @click.group() @click.version_option(package_name="cocoindex", message="%(prog)s version %(version)s") @click.option( "--env-file", - type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True), + type=click.Path( + exists=True, file_okay=True, dir_okay=False, readable=True, resolve_path=True + ), help="Path to a .env file to load environment variables from. " - "If not provided, attempts to load '.env' from the current directory.", + "If not provided, attempts to load '.env' from the current directory.", default=None, - show_default=False + show_default=False, ) def cli(env_file: str | None): """ @@ -135,6 +143,7 @@ def cli(env_file: str | None): except Exception as e: raise click.ClickException(f"Failed to initialize CocoIndex library: {e}") + @cli.command() @click.argument("app_target", type=str, required=False) def ls(app_target: str | None): @@ -168,9 +177,11 @@ def ls(app_target: str | None): has_missing = True if has_missing: - click.echo('') - click.echo('Notes:') - click.echo(' [+]: Flows present in the current process, but missing setup.') + click.echo("") + click.echo("Notes:") + click.echo( + " [+]: Flows present in the current process, but missing setup." + ) else: if not persisted_flow_names: @@ -180,9 +191,12 @@ def ls(app_target: str | None): for name in sorted(persisted_flow_names): click.echo(name) + @cli.command() @click.argument("app_flow_specifier", type=str) -@click.option("--color/--no-color", default=True, help="Enable or disable colored output.") +@click.option( + "--color/--no-color", default=True, help="Enable or disable colored output." +) @click.option("--verbose", is_flag=True, help="Show verbose output with full details.") def show(app_flow_specifier: str, color: bool, verbose: bool): """ @@ -209,7 +223,7 @@ def show(app_flow_specifier: str, color: bool, verbose: bool): table = Table( title=f"Schema for Flow: {fl.name}", title_style="cyan", - header_style="bold magenta" + header_style="bold magenta", ) table.add_column("Field", style="cyan") table.add_column("Type", style="green") @@ -218,6 +232,7 @@ def show(app_flow_specifier: str, color: bool, verbose: bool): table.add_row(field_name, field_type, attr_str) console.print(table) + @cli.command() @click.argument("app_target", type=str) def setup(app_target: str): @@ -236,18 +251,28 @@ def setup(app_target: str): click.echo("No changes need to be pushed.") return if not click.confirm( - "Changes need to be pushed. Continue? 
[yes/N]", default=False, show_default=False): + "Changes need to be pushed. Continue? [yes/N]", + default=False, + show_default=False, + ): return apply_setup_changes(setup_status) + @cli.command("drop") @click.argument("app_target", type=str, required=False) @click.argument("flow_name", type=str, nargs=-1) @click.option( - "-a", "--all", "drop_all", is_flag=True, show_default=True, default=False, + "-a", + "--all", + "drop_all", + is_flag=True, + show_default=True, + default=False, help="Drop the backend setup for all flows with persisted setup, " - "even if not defined in the current process." - "If used, APP_TARGET and any listed flow names are ignored.") + "even if not defined in the current process." + "If used, APP_TARGET and any listed flow names are ignored.", +) def drop(app_target: str | None, flow_name: tuple[str, ...], drop_all: bool): """ Drop the backend setup for flows. @@ -263,20 +288,29 @@ def drop(app_target: str | None, flow_name: tuple[str, ...], drop_all: bool): if drop_all: if app_target or flow_name: - click.echo("Warning: When --all is used, APP_TARGET and any individual flow names are ignored.", err=True) + click.echo( + "Warning: When --all is used, APP_TARGET and any individual flow names are ignored.", + err=True, + ) flow_names = flow_names_with_setup() elif app_target: app_ref = _get_app_ref_from_specifier(app_target) _load_user_app(app_ref) if flow_name: flow_names = list(flow_name) - click.echo(f"Preparing to drop specified flows: {', '.join(flow_names)} (in '{app_ref}').", err=True) + click.echo( + f"Preparing to drop specified flows: {', '.join(flow_names)} (in '{app_ref}').", + err=True, + ) else: flow_names = flow.flow_names() if not flow_names: click.echo(f"No flows found defined in '{app_ref}' to drop.") return - click.echo(f"Preparing to drop all flows defined in '{app_ref}': {', '.join(flow_names)}.", err=True) + click.echo( + f"Preparing to drop all flows defined in '{app_ref}': {', '.join(flow_names)}.", + err=True, + ) else: raise click.UsageError( "Missing arguments. You must either provide an APP_TARGET (to target app-specific flows) " @@ -294,19 +328,32 @@ def drop(app_target: str | None, flow_name: tuple[str, ...], drop_all: bool): return if not click.confirm( f"\nThis will apply changes to drop setup for: {', '.join(flow_names)}. Continue? [yes/N]", - default=False, show_default=False): + default=False, + show_default=False, + ): click.echo("Drop operation aborted by user.") return apply_setup_changes(setup_status) + @cli.command() @click.argument("app_flow_specifier", type=str) @click.option( - "-L", "--live", is_flag=True, show_default=True, default=False, - help="Continuously watch changes from data sources and apply to the target index.") + "-L", + "--live", + is_flag=True, + show_default=True, + default=False, + help="Continuously watch changes from data sources and apply to the target index.", +) @click.option( - "-q", "--quiet", is_flag=True, show_default=True, default=False, - help="Avoid printing anything to the standard output, e.g. statistics.") + "-q", + "--quiet", + is_flag=True, + show_default=True, + default=False, + help="Avoid printing anything to the standard output, e.g. statistics.", +) def update(app_flow_specifier: str, live: bool, quiet: bool): """ Update the index to reflect the latest data from data sources. 
@@ -325,14 +372,23 @@ def update(app_flow_specifier: str, live: bool, quiet: bool): updater.wait() return updater.update_stats() + @cli.command() @click.argument("app_flow_specifier", type=str) @click.option( - "-o", "--output-dir", type=str, required=False, - help="The directory to dump the output to.") + "-o", + "--output-dir", + type=str, + required=False, + help="The directory to dump the output to.", +) @click.option( - "--cache/--no-cache", is_flag=True, show_default=True, default=True, - help="Use already-cached intermediate data if available.") + "--cache/--no-cache", + is_flag=True, + show_default=True, + default=True, + help="Use already-cached intermediate data if available.", +) def evaluate(app_flow_specifier: str, output_dir: str | None, cache: bool = True): """ Evaluate the flow and dump flow outputs to files. @@ -359,32 +415,64 @@ def evaluate(app_flow_specifier: str, output_dir: str | None, cache: bool = True options = flow.EvaluateAndDumpOptions(output_dir=output_dir, use_cache=cache) fl.evaluate_and_dump(options) + @cli.command() @click.argument("app_target", type=str) @click.option( - "-a", "--address", type=str, + "-a", + "--address", + type=str, help="The address to bind the server to, in the format of IP:PORT. " - "If unspecified, the address specified in COCOINDEX_SERVER_ADDRESS will be used.") + "If unspecified, the address specified in COCOINDEX_SERVER_ADDRESS will be used.", +) @click.option( - "-c", "--cors-origin", type=str, + "-c", + "--cors-origin", + type=str, help="The origins of the clients (e.g. CocoInsight UI) to allow CORS from. " - "Multiple origins can be specified as a comma-separated list. " - "e.g. `https://cocoindex.io,http://localhost:3000`. " - "Origins specified in COCOINDEX_SERVER_CORS_ORIGINS will also be included.") + "Multiple origins can be specified as a comma-separated list. " + "e.g. `https://cocoindex.io,http://localhost:3000`. " + "Origins specified in COCOINDEX_SERVER_CORS_ORIGINS will also be included.", +) @click.option( - "-ci", "--cors-cocoindex", is_flag=True, show_default=True, default=False, - help=f"Allow {COCOINDEX_HOST} to access the server.") + "-ci", + "--cors-cocoindex", + is_flag=True, + show_default=True, + default=False, + help=f"Allow {COCOINDEX_HOST} to access the server.", +) @click.option( - "-cl", "--cors-local", type=int, - help="Allow http://localhost: to access the server.") + "-cl", + "--cors-local", + type=int, + help="Allow http://localhost: to access the server.", +) @click.option( - "-L", "--live-update", is_flag=True, show_default=True, default=False, - help="Continuously watch changes from data sources and apply to the target index.") + "-L", + "--live-update", + is_flag=True, + show_default=True, + default=False, + help="Continuously watch changes from data sources and apply to the target index.", +) @click.option( - "-q", "--quiet", is_flag=True, show_default=True, default=False, - help="Avoid printing anything to the standard output, e.g. statistics.") -def server(app_target: str, address: str | None, live_update: bool, quiet: bool, - cors_origin: str | None, cors_cocoindex: bool, cors_local: int | None): + "-q", + "--quiet", + is_flag=True, + show_default=True, + default=False, + help="Avoid printing anything to the standard output, e.g. statistics.", +) +def server( + app_target: str, + address: str | None, + live_update: bool, + quiet: bool, + cors_origin: str | None, + cors_cocoindex: bool, + cors_local: int | None, +): """ Start a HTTP server providing REST APIs. 
@@ -421,20 +509,26 @@ def server(app_target: str, address: str | None, live_update: bool, quiet: bool, def _flow_name(name: str | None) -> str: names = flow.flow_names() - available = ', '.join(sorted(names)) + available = ", ".join(sorted(names)) if name is not None: if name not in names: - raise click.BadParameter(f"Flow '{name}' not found.\nAvailable: {available if names else 'None'}") + raise click.BadParameter( + f"Flow '{name}' not found.\nAvailable: {available if names else 'None'}" + ) return name if len(names) == 0: raise click.UsageError("No flows available in the loaded application.") elif len(names) == 1: return names[0] else: - raise click.UsageError(f"Multiple flows available, please specify which flow to target by appending :FlowName to the APP_TARGET.\nAvailable: {available}") + raise click.UsageError( + f"Multiple flows available, please specify which flow to target by appending :FlowName to the APP_TARGET.\nAvailable: {available}" + ) + def _flow_by_name(name: str | None) -> flow.Flow: return flow.flow_by_name(_flow_name(name)) + if __name__ == "__main__": - cli() \ No newline at end of file + cli() diff --git a/python/cocoindex/convert.py b/python/cocoindex/convert.py index 54817e31a..da9e953bf 100644 --- a/python/cocoindex/convert.py +++ b/python/cocoindex/convert.py @@ -1,6 +1,7 @@ """ Utilities to convert between Python and engine values. """ + import dataclasses import datetime import inspect @@ -8,28 +9,40 @@ from enum import Enum from typing import Any, Callable, get_origin, Mapping -from .typing import analyze_type_info, encode_enriched_type, is_namedtuple_type, TABLE_TYPES, KEY_FIELD_NAME +from .typing import ( + analyze_type_info, + encode_enriched_type, + is_namedtuple_type, + TABLE_TYPES, + KEY_FIELD_NAME, +) def encode_engine_value(value: Any) -> Any: """Encode a Python value to an engine value.""" if dataclasses.is_dataclass(value): - return [encode_engine_value(getattr(value, f.name)) for f in dataclasses.fields(value)] + return [ + encode_engine_value(getattr(value, f.name)) + for f in dataclasses.fields(value) + ] if is_namedtuple_type(type(value)): return [encode_engine_value(getattr(value, name)) for name in value._fields] if isinstance(value, (list, tuple)): return [encode_engine_value(v) for v in value] if isinstance(value, dict): - return [[encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items()] + return [ + [encode_engine_value(k)] + encode_engine_value(v) for k, v in value.items() + ] if isinstance(value, uuid.UUID): return value.bytes return value + def make_engine_value_decoder( - field_path: list[str], - src_type: dict[str, Any], - dst_annotation, - ) -> Callable[[Any], Any]: + field_path: list[str], + src_type: dict[str, Any], + dst_annotation, +) -> Callable[[Any], Any]: """ Make a decoder from an engine value to a Python value. @@ -42,12 +55,18 @@ def make_engine_value_decoder( A decoder from an engine value to a Python value. """ - src_type_kind = src_type['kind'] + src_type_kind = src_type["kind"] - if dst_annotation is None or dst_annotation is inspect.Parameter.empty or dst_annotation is Any: - if src_type_kind == 'Struct' or src_type_kind in TABLE_TYPES: - raise ValueError(f"Missing type annotation for `{''.join(field_path)}`." 
- f"It's required for {src_type_kind} type.") + if ( + dst_annotation is None + or dst_annotation is inspect.Parameter.empty + or dst_annotation is Any + ): + if src_type_kind == "Struct" or src_type_kind in TABLE_TYPES: + raise ValueError( + f"Missing type annotation for `{''.join(field_path)}`." + f"It's required for {src_type_kind} type." + ) return lambda value: value dst_type_info = analyze_type_info(dst_annotation) @@ -55,95 +74,118 @@ def make_engine_value_decoder( if src_type_kind != dst_type_info.kind: raise ValueError( f"Type mismatch for `{''.join(field_path)}`: " - f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})") + f"passed in {src_type_kind}, declared {dst_annotation} ({dst_type_info.kind})" + ) if dst_type_info.struct_type is not None: return _make_engine_struct_value_decoder( - field_path, src_type['fields'], dst_type_info.struct_type) + field_path, src_type["fields"], dst_type_info.struct_type + ) if src_type_kind in TABLE_TYPES: - field_path.append('[*]') + field_path.append("[*]") elem_type_info = analyze_type_info(dst_type_info.elem_type) if elem_type_info.struct_type is None: - raise ValueError(f"Type mismatch for `{''.join(field_path)}`: " - f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected") - engine_fields_schema = src_type['row']['fields'] + raise ValueError( + f"Type mismatch for `{''.join(field_path)}`: " + f"declared `{dst_type_info.kind}`, a dataclass or NamedTuple type expected" + ) + engine_fields_schema = src_type["row"]["fields"] if elem_type_info.key_type is not None: key_field_schema = engine_fields_schema[0] field_path.append(f".{key_field_schema.get('name', KEY_FIELD_NAME)}") key_decoder = make_engine_value_decoder( - field_path, key_field_schema['type'], elem_type_info.key_type) + field_path, key_field_schema["type"], elem_type_info.key_type + ) field_path.pop() value_decoder = _make_engine_struct_value_decoder( - field_path, engine_fields_schema[1:], elem_type_info.struct_type) + field_path, engine_fields_schema[1:], elem_type_info.struct_type + ) + def decode(value): if value is None: return None return {key_decoder(v[0]): value_decoder(v[1:]) for v in value} else: elem_decoder = _make_engine_struct_value_decoder( - field_path, engine_fields_schema, elem_type_info.struct_type) + field_path, engine_fields_schema, elem_type_info.struct_type + ) + def decode(value): if value is None: return None return [elem_decoder(v) for v in value] + field_path.pop() return decode - if src_type_kind == 'Uuid': + if src_type_kind == "Uuid": return lambda value: uuid.UUID(bytes=value) return lambda value: value + def _make_engine_struct_value_decoder( - field_path: list[str], - src_fields: list[dict[str, Any]], - dst_struct_type: type, - ) -> Callable[[list], Any]: + field_path: list[str], + src_fields: list[dict[str, Any]], + dst_struct_type: type, +) -> Callable[[list], Any]: """Make a decoder from an engine field values to a Python value.""" - src_name_to_idx = {f['name']: i for i, f in enumerate(src_fields)} - + src_name_to_idx = {f["name"]: i for i, f in enumerate(src_fields)} + parameters: Mapping[str, inspect.Parameter] if dataclasses.is_dataclass(dst_struct_type): parameters = inspect.signature(dst_struct_type).parameters elif is_namedtuple_type(dst_struct_type): - defaults = getattr(dst_struct_type, '_field_defaults', {}) - fields = getattr(dst_struct_type, '_fields', ()) + defaults = getattr(dst_struct_type, "_field_defaults", {}) + fields = getattr(dst_struct_type, "_fields", ()) parameters = { name: 
inspect.Parameter( name=name, kind=inspect.Parameter.POSITIONAL_OR_KEYWORD, default=defaults.get(name, inspect.Parameter.empty), - annotation=dst_struct_type.__annotations__.get(name, inspect.Parameter.empty) + annotation=dst_struct_type.__annotations__.get( + name, inspect.Parameter.empty + ), ) for name in fields } else: raise ValueError(f"Unsupported struct type: {dst_struct_type}") - def make_closure_for_value(name: str, param: inspect.Parameter) -> Callable[[list], Any]: + def make_closure_for_value( + name: str, param: inspect.Parameter + ) -> Callable[[list], Any]: src_idx = src_name_to_idx.get(name) if src_idx is not None: - field_path.append(f'.{name}') + field_path.append(f".{name}") field_decoder = make_engine_value_decoder( - field_path, src_fields[src_idx]['type'], param.annotation) + field_path, src_fields[src_idx]["type"], param.annotation + ) field_path.pop() - return lambda values: field_decoder(values[src_idx]) if len(values) > src_idx else param.default + return ( + lambda values: field_decoder(values[src_idx]) + if len(values) > src_idx + else param.default + ) default_value = param.default if default_value is inspect.Parameter.empty: raise ValueError( - f"Field without default value is missing in input: {''.join(field_path)}") + f"Field without default value is missing in input: {''.join(field_path)}" + ) return lambda _: default_value field_value_decoder = [ - make_closure_for_value(name, param) - for (name, param) in parameters.items()] + make_closure_for_value(name, param) for (name, param) in parameters.items() + ] return lambda values: dst_struct_type( - *(decoder(values) for decoder in field_value_decoder)) + *(decoder(values) for decoder in field_value_decoder) + ) + def dump_engine_object(v: Any) -> Any: """Recursively dump an object for engine. Engine side uses `Pythonized` to catch.""" @@ -157,14 +199,14 @@ def dump_engine_object(v: Any) -> Any: total_secs = v.total_seconds() secs = int(total_secs) nanos = int((total_secs - secs) * 1e9) - return {'secs': secs, 'nanos': nanos} - elif hasattr(v, '__dict__'): + return {"secs": secs, "nanos": nanos} + elif hasattr(v, "__dict__"): s = {k: dump_engine_object(v) for k, v in v.__dict__.items()} - if hasattr(v, 'kind') and 'kind' not in s: - s['kind'] = v.kind + if hasattr(v, "kind") and "kind" not in s: + s["kind"] = v.kind return s elif isinstance(v, (list, tuple)): return [dump_engine_object(item) for item in v] elif isinstance(v, dict): return {k: dump_engine_object(v) for k, v in v.items()} - return v \ No newline at end of file + return v diff --git a/python/cocoindex/flow.py b/python/cocoindex/flow.py index d02735a1e..d5e2122fc 100644 --- a/python/cocoindex/flow.py +++ b/python/cocoindex/flow.py @@ -10,7 +10,17 @@ import datetime import functools -from typing import Any, Callable, Sequence, TypeVar, Generic, get_args, get_origin, Type, NamedTuple +from typing import ( + Any, + Callable, + Sequence, + TypeVar, + Generic, + get_args, + get_origin, + Type, + NamedTuple, +) from threading import Lock from enum import Enum from dataclasses import dataclass @@ -25,6 +35,7 @@ from .typing import encode_enriched_type from .runtime import execution_context + class _NameBuilder: _existing_names: set[str] _next_name_index: dict[str, int] @@ -51,19 +62,27 @@ def build_name(self, name: str | None, /, prefix: str) -> str: return name -_WORD_BOUNDARY_RE = re.compile('(? 
str: - return _WORD_BOUNDARY_RE.sub('_', name).lower() + return _WORD_BOUNDARY_RE.sub("_", name).lower() + def _create_data_slice( - flow_builder_state: _FlowBuilderState, - creator: Callable[[_engine.DataScopeRef | None, str | None], _engine.DataSlice], - name: str | None = None) -> DataSlice: + flow_builder_state: _FlowBuilderState, + creator: Callable[[_engine.DataScopeRef | None, str | None], _engine.DataSlice], + name: str | None = None, +) -> DataSlice: if name is None: - return DataSlice(_DataSliceState( - flow_builder_state, - lambda target: - creator(target[0], target[1]) if target is not None else creator(None, None))) + return DataSlice( + _DataSliceState( + flow_builder_state, + lambda target: creator(target[0], target[1]) + if target is not None + else creator(None, None), + ) + ) else: return DataSlice(_DataSliceState(flow_builder_state, creator(None, name))) @@ -71,20 +90,25 @@ def _create_data_slice( def _spec_kind(spec: Any) -> str: return spec.__class__.__name__ -T = TypeVar('T') + +T = TypeVar("T") + class _DataSliceState: flow_builder_state: _FlowBuilderState _lazy_lock: Lock | None = None # None means it's not lazy. _data_slice: _engine.DataSlice | None = None - _data_slice_creator: Callable[[tuple[_engine.DataScopeRef, str] | None], - _engine.DataSlice] | None = None + _data_slice_creator: ( + Callable[[tuple[_engine.DataScopeRef, str] | None], _engine.DataSlice] | None + ) = None def __init__( - self, flow_builder_state: _FlowBuilderState, - data_slice: _engine.DataSlice | Callable[[tuple[_engine.DataScopeRef, str] | None], - _engine.DataSlice]): + self, + flow_builder_state: _FlowBuilderState, + data_slice: _engine.DataSlice + | Callable[[tuple[_engine.DataScopeRef, str] | None], _engine.DataSlice], + ): self.flow_builder_state = flow_builder_state if isinstance(data_slice, _engine.DataSlice): @@ -124,6 +148,7 @@ def attach_to_scope(self, scope: _engine.DataScopeRef, field_name: str) -> None: # TODO: We'll support this by an identity transformer or "aliasing" in the future. raise ValueError("DataSlice is already attached to a field") + class DataSlice(Generic[T]): """A data slice represents a slice of data in a flow. 
It's readonly.""" @@ -167,22 +192,27 @@ def transform(self, fn_spec: op.FunctionSpec, *args, **kwargs) -> DataSlice: transform_args: list[tuple[Any, str | None]] transform_args = [(self._state.engine_data_slice, None)] - transform_args += [(self._state.flow_builder_state.get_data_slice(v), None) for v in args] - transform_args += [(self._state.flow_builder_state.get_data_slice(v), k) - for (k, v) in kwargs.items()] + transform_args += [ + (self._state.flow_builder_state.get_data_slice(v), None) for v in args + ] + transform_args += [ + (self._state.flow_builder_state.get_data_slice(v), k) + for (k, v) in kwargs.items() + ] flow_builder_state = self._state.flow_builder_state return _create_data_slice( flow_builder_state, - lambda target_scope, name: - flow_builder_state.engine_flow_builder.transform( - _spec_kind(fn_spec), - dump_engine_object(fn_spec), - transform_args, - target_scope, - flow_builder_state.field_name_builder.build_name( - name, prefix=_to_snake_case(_spec_kind(fn_spec))+'_'), - )) + lambda target_scope, name: flow_builder_state.engine_flow_builder.transform( + _spec_kind(fn_spec), + dump_engine_object(fn_spec), + transform_args, + target_scope, + flow_builder_state.field_name_builder.build_name( + name, prefix=_to_snake_case(_spec_kind(fn_spec)) + "_" + ), + ), + ) def call(self, func: Callable[[DataSlice], T], *args, **kwargs) -> T: """ @@ -190,18 +220,23 @@ def call(self, func: Callable[[DataSlice], T], *args, **kwargs) -> T: """ return func(self, *args, **kwargs) + def _data_slice_state(data_slice: DataSlice) -> _DataSliceState: return data_slice._state # pylint: disable=protected-access + class DataScope: """ A data scope in a flow. It has multple fields and collectors, and allow users to add new fields and collectors. """ + _flow_builder_state: _FlowBuilderState _engine_data_scope: _engine.DataScopeRef - def __init__(self, flow_builder_state: _FlowBuilderState, data_scope: _engine.DataScopeRef): + def __init__( + self, flow_builder_state: _FlowBuilderState, data_scope: _engine.DataScopeRef + ): self._flow_builder_state = flow_builder_state self._engine_data_scope = data_scope @@ -212,10 +247,14 @@ def __repr__(self): return repr(self._engine_data_scope) def __getitem__(self, field_name: str) -> DataSlice: - return DataSlice(_DataSliceState( - self._flow_builder_state, - self._flow_builder_state.engine_flow_builder.scope_field( - self._engine_data_scope, field_name))) + return DataSlice( + _DataSliceState( + self._flow_builder_state, + self._flow_builder_state.engine_flow_builder.scope_field( + self._engine_data_scope, field_name + ), + ) + ) def __setitem__(self, field_name: str, value: DataSlice): value._state.attach_to_scope(self._engine_data_scope, field_name) @@ -233,23 +272,32 @@ def add_collector(self, name: str | None = None) -> DataCollector: return DataCollector( self._flow_builder_state, self._engine_data_scope.add_collector( - self._flow_builder_state.field_name_builder.build_name(name, prefix="_collector_") - ) + self._flow_builder_state.field_name_builder.build_name( + name, prefix="_collector_" + ) + ), ) + class GeneratedField(Enum): """ A generated field is automatically set by the engine. 
""" + UUID = "Uuid" + class DataCollector: """A data collector is used to collect data into a collector.""" + _flow_builder_state: _FlowBuilderState _engine_data_collector: _engine.DataCollector - def __init__(self, flow_builder_state: _FlowBuilderState, - data_collector: _engine.DataCollector): + def __init__( + self, + flow_builder_state: _FlowBuilderState, + data_collector: _engine.DataCollector, + ): self._flow_builder_state = flow_builder_state self._engine_data_collector = data_collector @@ -268,45 +316,62 @@ def collect(self, **kwargs): else: raise ValueError(f"Unexpected generated field: {v}") else: - regular_kwargs.append( - (k, self._flow_builder_state.get_data_slice(v))) + regular_kwargs.append((k, self._flow_builder_state.get_data_slice(v))) self._flow_builder_state.engine_flow_builder.collect( - self._engine_data_collector, regular_kwargs, auto_uuid_field) + self._engine_data_collector, regular_kwargs, auto_uuid_field + ) - def export(self, name: str, target_spec: op.StorageSpec, /, *, - primary_key_fields: Sequence[str], - vector_indexes: Sequence[index.VectorIndexDef] = (), - vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (), - setup_by_user: bool = False): + def export( + self, + name: str, + target_spec: op.StorageSpec, + /, + *, + primary_key_fields: Sequence[str], + vector_indexes: Sequence[index.VectorIndexDef] = (), + vector_index: Sequence[tuple[str, index.VectorSimilarityMetric]] = (), + setup_by_user: bool = False, + ): """ Export the collected data to the specified target. `vector_index` is for backward compatibility only. Please use `vector_indexes` instead. """ if not isinstance(target_spec, op.StorageSpec): - raise ValueError("export() can only be called on a CocoIndex target storage") + raise ValueError( + "export() can only be called on a CocoIndex target storage" + ) # For backward compatibility only. if len(vector_indexes) == 0 and len(vector_index) > 0: - vector_indexes = [index.VectorIndexDef(field_name=field_name, metric=metric) - for field_name, metric in vector_index] + vector_indexes = [ + index.VectorIndexDef(field_name=field_name, metric=metric) + for field_name, metric in vector_index + ] index_options = index.IndexOptions( primary_key_fields=primary_key_fields, vector_indexes=vector_indexes, ) self._flow_builder_state.engine_flow_builder.export( - name, _spec_kind(target_spec), dump_engine_object(target_spec), - dump_engine_object(index_options), self._engine_data_collector, setup_by_user) + name, + _spec_kind(target_spec), + dump_engine_object(target_spec), + dump_engine_object(index_options), + self._engine_data_collector, + setup_by_user, + ) _flow_name_builder = _NameBuilder() + class _FlowBuilderState: """ A flow builder is used to build a flow. """ + engine_flow_builder: _engine.FlowBuilder field_name_builder: _NameBuilder @@ -322,17 +387,21 @@ def get_data_slice(self, v: Any) -> _engine.DataSlice: return v._state.engine_data_slice return self.engine_flow_builder.constant(encode_enriched_type(type(v)), v) + @dataclass class _SourceRefreshOptions: """ Options for refreshing a source. """ + refresh_interval: datetime.timedelta | None = None + class FlowBuilder: """ A flow builder is used to build a flow. 
""" + _state: _FlowBuilderState def __init__(self, state: _FlowBuilderState): @@ -344,10 +413,14 @@ def __str__(self): def __repr__(self): return repr(self._state.engine_flow_builder) - def add_source(self, spec: op.SourceSpec, /, *, - name: str | None = None, - refresh_interval: datetime.timedelta | None = None, - ) -> DataSlice: + def add_source( + self, + spec: op.SourceSpec, + /, + *, + name: str | None = None, + refresh_interval: datetime.timedelta | None = None, + ) -> DataSlice: """ Import a source to the flow. """ @@ -360,10 +433,13 @@ def add_source(self, spec: op.SourceSpec, /, *, dump_engine_object(spec), target_scope, self._state.field_name_builder.build_name( - name, prefix=_to_snake_case(_spec_kind(spec))+'_'), - dump_engine_object(_SourceRefreshOptions(refresh_interval=refresh_interval)), + name, prefix=_to_snake_case(_spec_kind(spec)) + "_" + ), + dump_engine_object( + _SourceRefreshOptions(refresh_interval=refresh_interval) + ), ), - name + name, ) def declare(self, spec: op.DeclarationSpec): @@ -372,18 +448,22 @@ def declare(self, spec: op.DeclarationSpec): """ self._state.engine_flow_builder.declare(dump_engine_object(spec)) + @dataclass class FlowLiveUpdaterOptions: """ Options for live updating a flow. """ + live_mode: bool = True print_stats: bool = False + class FlowLiveUpdater: """ A live updater for a flow. """ + _flow: Flow _options: FlowLiveUpdaterOptions _engine_live_updater: _engine.FlowLiveUpdater | None = None @@ -419,7 +499,8 @@ async def start_async(self) -> None: Start the live updater. """ self._engine_live_updater = await _engine.FlowLiveUpdater.create( - await self._flow.internal_flow_async(), dump_engine_object(self._options)) + await self._flow.internal_flow_async(), dump_engine_object(self._options) + ) def wait(self) -> None: """ @@ -456,22 +537,28 @@ class EvaluateAndDumpOptions: """ Options for evaluating and dumping a flow. """ + output_dir: str use_cache: bool = True + class Flow: """ A flow describes an indexing pipeline. """ + _name: str _full_name: str _lazy_engine_flow: Callable[[], _engine.Flow] - def __init__(self, name: str, full_name: str, engine_flow_creator: Callable[[], _engine.Flow]): + def __init__( + self, name: str, full_name: str, engine_flow_creator: Callable[[], _engine.Flow] + ): self._name = name self._full_name = full_name engine_flow = None lock = Lock() + def _lazy_engine_flow() -> _engine.Flow: nonlocal engine_flow, lock if engine_flow is None: @@ -479,6 +566,7 @@ def _lazy_engine_flow() -> _engine.Flow: if engine_flow is None: engine_flow = engine_flow_creator() return engine_flow + self._lazy_engine_flow = _lazy_engine_flow def _render_spec(self, verbose: bool = False) -> Tree: @@ -501,8 +589,10 @@ def build_tree(label: str, lines: list): return tree def _get_spec(self, verbose: bool = False) -> _engine.RenderedSpec: - return self._lazy_engine_flow().get_spec(output_mode="verbose" if verbose else "concise") - + return self._lazy_engine_flow().get_spec( + output_mode="verbose" if verbose else "concise" + ) + def _get_schema(self) -> list[tuple[str, str, str]]: return self._lazy_engine_flow().get_schema() @@ -538,7 +628,9 @@ async def update_async(self) -> _engine.IndexUpdateInfo: Update the index defined by the flow. Once the function returns, the index is fresh up to the moment when the function is called. 
""" - async with FlowLiveUpdater(self, FlowLiveUpdaterOptions(live_mode=False)) as updater: + async with FlowLiveUpdater( + self, FlowLiveUpdaterOptions(live_mode=False) + ) as updater: await updater.wait_async() return updater.update_stats() @@ -560,19 +652,26 @@ async def internal_flow_async(self) -> _engine.Flow: """ return await asyncio.to_thread(self.internal_flow) -def _create_lazy_flow(name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow: + +def _create_lazy_flow( + name: str | None, fl_def: Callable[[FlowBuilder, DataScope], None] +) -> Flow: """ Create a flow without really building it yet. The flow will be built the first time when it's really needed. """ flow_name = _flow_name_builder.build_name(name, prefix="_flow_") flow_full_name = get_full_flow_name(flow_name) + def _create_engine_flow() -> _engine.Flow: flow_builder_state = _FlowBuilderState(flow_full_name) root_scope = DataScope( - flow_builder_state, flow_builder_state.engine_flow_builder.root_scope()) + flow_builder_state, flow_builder_state.engine_flow_builder.root_scope() + ) fl_def(FlowBuilder(flow_builder_state), root_scope) - return flow_builder_state.engine_flow_builder.build_flow(execution_context.event_loop) + return flow_builder_state.engine_flow_builder.build_flow( + execution_context.event_loop + ) return Flow(flow_name, flow_full_name, _create_engine_flow) @@ -580,28 +679,34 @@ def _create_engine_flow() -> _engine.Flow: _flows_lock = Lock() _flows: dict[str, Flow] = {} + def get_full_flow_name(name: str) -> str: """ Get the full name of a flow. """ return f"{setting.get_app_namespace(trailing_delimiter='.')}{name}" + def add_flow_def(name: str, fl_def: Callable[[FlowBuilder, DataScope], None]) -> Flow: """Add a flow definition to the cocoindex library.""" - if not all(c.isalnum() or c == '_' for c in name): - raise ValueError(f"Flow name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed.") + if not all(c.isalnum() or c == "_" for c in name): + raise ValueError( + f"Flow name '{name}' contains invalid characters. Only alphanumeric characters and underscores are allowed." + ) with _flows_lock: if name in _flows: raise KeyError(f"Flow with name {name} already exists") fl = _flows[name] = _create_lazy_flow(name, fl_def) return fl -def flow_def(name = None) -> Callable[[Callable[[FlowBuilder, DataScope], None]], Flow]: + +def flow_def(name=None) -> Callable[[Callable[[FlowBuilder, DataScope], None]], Flow]: """ A decorator to wrap the flow definition. """ return lambda fl_def: add_flow_def(name or fl_def.__name__, fl_def) + def flow_names() -> list[str]: """ Get the names of all flows. @@ -609,6 +714,7 @@ def flow_names() -> list[str]: with _flows_lock: return list(_flows.keys()) + def flows() -> dict[str, Flow]: """ Get all flows. @@ -616,6 +722,7 @@ def flows() -> dict[str, Flow]: with _flows_lock: return dict(_flows) + def flow_by_name(name: str) -> Flow: """ Get a flow by name. @@ -623,12 +730,14 @@ def flow_by_name(name: str) -> Flow: with _flows_lock: return _flows[name] + def ensure_all_flows_built() -> None: """ Ensure all flows are built. """ execution_context.run(ensure_all_flows_built_async()) + async def ensure_all_flows_built_async() -> None: """ Ensure all flows are built. 
@@ -636,26 +745,39 @@ async def ensure_all_flows_built_async() -> None: for fl in flows().values(): await fl.internal_flow_async() -def update_all_flows(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]: + +def update_all_flows( + options: FlowLiveUpdaterOptions, +) -> dict[str, _engine.IndexUpdateInfo]: """ Update all flows. """ return execution_context.run(update_all_flows_async(options)) -async def update_all_flows_async(options: FlowLiveUpdaterOptions) -> dict[str, _engine.IndexUpdateInfo]: + +async def update_all_flows_async( + options: FlowLiveUpdaterOptions, +) -> dict[str, _engine.IndexUpdateInfo]: """ Update all flows. """ await ensure_all_flows_built_async() + async def _update_flow(name: str, fl: Flow) -> tuple[str, _engine.IndexUpdateInfo]: async with FlowLiveUpdater(fl, options) as updater: await updater.wait_async() return (name, updater.update_stats()) + fls = flows() - all_stats = await asyncio.gather(*(_update_flow(name, fl) for (name, fl) in fls.items())) + all_stats = await asyncio.gather( + *(_update_flow(name, fl) for (name, fl) in fls.items()) + ) return dict(all_stats) -def _get_data_slice_annotation_type(data_slice_type: Type[DataSlice[T]]) -> Type[T] | None: + +def _get_data_slice_annotation_type( + data_slice_type: Type[DataSlice[T]], +) -> Type[T] | None: type_args = get_args(data_slice_type) if data_slice_type is inspect.Parameter.empty or data_slice_type is DataSlice: return None @@ -663,16 +785,20 @@ def _get_data_slice_annotation_type(data_slice_type: Type[DataSlice[T]]) -> Type raise ValueError(f"Expect a DataSlice[T] type, but got {data_slice_type}") return type_args[0] + _transform_flow_name_builder = _NameBuilder() + class TransformFlowInfo(NamedTuple): engine_flow: _engine.TransientFlow result_decoder: Callable[[Any], T] + class TransformFlow(Generic[T]): """ A transient transformation flow that transforms in-memory data. 
""" + _flow_fn: Callable[..., DataSlice[T]] _flow_name: str _flow_arg_types: list[Any] @@ -682,10 +808,16 @@ class TransformFlow(Generic[T]): _lazy_flow_info: TransformFlowInfo | None = None def __init__( - self, flow_fn: Callable[..., DataSlice[T]], - flow_arg_types: Sequence[Any], /, name: str | None = None): + self, + flow_fn: Callable[..., DataSlice[T]], + flow_arg_types: Sequence[Any], + /, + name: str | None = None, + ): self._flow_fn = flow_fn - self._flow_name = _transform_flow_name_builder.build_name(name, prefix="_transform_flow_") + self._flow_name = _transform_flow_name_builder.build_name( + name, prefix="_transform_flow_" + ) self._flow_arg_types = list(flow_arg_types) self._lazy_lock = asyncio.Lock() @@ -712,28 +844,48 @@ async def _build_flow_info_async(self) -> TransformFlowInfo: if len(sig.parameters) != len(self._flow_arg_types): raise ValueError( f"Number of parameters in the flow function ({len(sig.parameters)}) " - f"does not match the number of argument types ({len(self._flow_arg_types)})") + f"does not match the number of argument types ({len(self._flow_arg_types)})" + ) kwargs: dict[str, DataSlice] = {} - for (param_name, param), param_type in zip(sig.parameters.items(), self._flow_arg_types): - if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY): - raise ValueError(f"Parameter `{param_name}` is not a parameter can be passed by name") + for (param_name, param), param_type in zip( + sig.parameters.items(), self._flow_arg_types + ): + if param.kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + raise ValueError( + f"Parameter `{param_name}` is not a parameter can be passed by name" + ) encoded_type = encode_enriched_type(param_type) if encoded_type is None: raise ValueError(f"Parameter `{param_name}` has no type annotation") - engine_ds = flow_builder_state.engine_flow_builder.add_direct_input(param_name, encoded_type) - kwargs[param_name] = DataSlice(_DataSliceState(flow_builder_state, engine_ds)) + engine_ds = flow_builder_state.engine_flow_builder.add_direct_input( + param_name, encoded_type + ) + kwargs[param_name] = DataSlice( + _DataSliceState(flow_builder_state, engine_ds) + ) output = self._flow_fn(**kwargs) flow_builder_state.engine_flow_builder.set_direct_output( - _data_slice_state(output).engine_data_slice) - engine_flow = await flow_builder_state.engine_flow_builder.build_transient_flow_async(execution_context.event_loop) + _data_slice_state(output).engine_data_slice + ) + engine_flow = ( + await flow_builder_state.engine_flow_builder.build_transient_flow_async( + execution_context.event_loop + ) + ) self._param_names = list(sig.parameters.keys()) - engine_return_type = _data_slice_state(output).engine_data_slice.data_type().schema() + engine_return_type = ( + _data_slice_state(output).engine_data_slice.data_type().schema() + ) python_return_type = _get_data_slice_annotation_type(sig.return_annotation) - result_decoder = make_engine_value_decoder([], engine_return_type['type'], python_return_type) + result_decoder = make_engine_value_decoder( + [], engine_return_type["type"], python_return_type + ) return TransformFlowInfo(engine_flow, result_decoder) @@ -776,18 +928,24 @@ def transform_flow() -> Callable[[Callable[..., DataSlice[T]]], TransformFlow[T] """ A decorator to wrap the transform function. 
""" + def _transform_flow_wrapper(fn: Callable[..., DataSlice[T]]): sig = inspect.signature(fn) arg_types = [] - for (param_name, param) in sig.parameters.items(): - if param.kind not in (inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY): - raise ValueError(f"Parameter `{param_name}` is not a parameter can be passed by name") + for param_name, param in sig.parameters.items(): + if param.kind not in ( + inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY, + ): + raise ValueError( + f"Parameter `{param_name}` is not a parameter can be passed by name" + ) value_type_annotation = _get_data_slice_annotation_type(param.annotation) if value_type_annotation is None: raise ValueError( f"Parameter `{param_name}` for {fn} has no value type annotation. " - "Please use `cocoindex.DataSlice[T]` where T is the type of the value.") + "Please use `cocoindex.DataSlice[T]` where T is the type of the value." + ) arg_types.append(value_type_annotation) _transform_flow = TransformFlow(fn, arg_types) diff --git a/python/cocoindex/functions.py b/python/cocoindex/functions.py index 8d65c9a72..2ce0fd95f 100644 --- a/python/cocoindex/functions.py +++ b/python/cocoindex/functions.py @@ -1,4 +1,5 @@ """All builtin functions.""" + from typing import Annotated, Any, TYPE_CHECKING from .typing import Float32, Vector, TypeAttr @@ -8,12 +9,15 @@ if TYPE_CHECKING: import sentence_transformers + class ParseJson(op.FunctionSpec): """Parse a text into a JSON object.""" + class SplitRecursively(op.FunctionSpec): """Split a document (in string) recursively.""" + class ExtractByLlm(op.FunctionSpec): """Extract information from a text using a LLM.""" @@ -21,6 +25,7 @@ class ExtractByLlm(op.FunctionSpec): output_type: type instruction: str | None = None + class SentenceTransformerEmbed(op.FunctionSpec): """ `SentenceTransformerEmbed` embeds a text into a vector space using the [SentenceTransformer](https://huggingface.co/sentence-transformers) library. @@ -30,9 +35,11 @@ class SentenceTransformerEmbed(op.FunctionSpec): model: The name of the SentenceTransformer model to use. args: Additional arguments to pass to the SentenceTransformer constructor. e.g. 
{"trust_remote_code": True} """ + model: str args: dict[str, Any] | None = None + @op.executor_class(gpu=True, cache=True, behavior_version=1) class SentenceTransformerEmbedExecutor: """Executor for SentenceTransformerEmbed.""" @@ -41,11 +48,15 @@ class SentenceTransformerEmbedExecutor: _model: "sentence_transformers.SentenceTransformer" def analyze(self, text): - import sentence_transformers # pylint: disable=import-outside-toplevel + import sentence_transformers # pylint: disable=import-outside-toplevel + args = self.spec.args or {} self._model = sentence_transformers.SentenceTransformer(self.spec.model, **args) dim = self._model.get_sentence_embedding_dimension() - return Annotated[Vector[Float32, dim], TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value)] + return Annotated[ + Vector[Float32, dim], + TypeAttr("cocoindex.io/vector_origin_text", text.analyzed_value), + ] def __call__(self, text: str) -> list[Float32]: return self._model.encode(text).tolist() diff --git a/python/cocoindex/index.py b/python/cocoindex/index.py index 6379425ac..a5ff06267 100644 --- a/python/cocoindex/index.py +++ b/python/cocoindex/index.py @@ -1,23 +1,29 @@ from enum import Enum from dataclasses import dataclass from typing import Sequence + + class VectorSimilarityMetric(Enum): COSINE_SIMILARITY = "CosineSimilarity" L2_DISTANCE = "L2Distance" INNER_PRODUCT = "InnerProduct" + @dataclass class VectorIndexDef: """ Define a vector index on a field. """ + field_name: str metric: VectorSimilarityMetric + @dataclass class IndexOptions: """ Options for an index. """ + primary_key_fields: Sequence[str] vector_indexes: Sequence[VectorIndexDef] = () diff --git a/python/cocoindex/lib.py b/python/cocoindex/lib.py index 6254fe6ac..2f00ea14b 100644 --- a/python/cocoindex/lib.py +++ b/python/cocoindex/lib.py @@ -1,6 +1,7 @@ """ Library level functions and states. """ + import warnings from typing import Callable, Any @@ -26,14 +27,16 @@ def start_server(settings: setting.ServerSettings): query.ensure_all_handlers_built() _engine.start_server(settings.__dict__) + def stop(): """Stop the cocoindex library.""" _engine.stop() + def main_fn( - settings: Any | None = None, - cocoindex_cmd: str | None = None, - ) -> Callable[[Callable], Callable]: + settings: Any | None = None, + cocoindex_cmd: str | None = None, +) -> Callable[[Callable], Callable]: """ DEPRECATED: The @cocoindex.main_fn() decorator is obsolete and has no effect. It will be removed in a future version, which will cause an AttributeError. @@ -64,9 +67,10 @@ def main_fn( "See cocoindex --help for more details.\n" "!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n", DeprecationWarning, - stacklevel=2 + stacklevel=2, ) def _main_wrapper(fn: Callable) -> Callable: return fn + return _main_wrapper diff --git a/python/cocoindex/llm.py b/python/cocoindex/llm.py index 5fcd0a732..b66f7b355 100644 --- a/python/cocoindex/llm.py +++ b/python/cocoindex/llm.py @@ -1,16 +1,20 @@ from dataclasses import dataclass from enum import Enum + class LlmApiType(Enum): """The type of LLM API to use.""" + OPENAI = "OpenAi" OLLAMA = "Ollama" GEMINI = "Gemini" ANTHROPIC = "Anthropic" + @dataclass class LlmSpec: """A specification for a LLM.""" + api_type: LlmApiType model: str address: str | None = None diff --git a/python/cocoindex/op.py b/python/cocoindex/op.py index 495a27e4a..667958605 100644 --- a/python/cocoindex/op.py +++ b/python/cocoindex/op.py @@ -1,6 +1,7 @@ """ Facilities for defining cocoindex operations. 
""" + import asyncio import dataclasses import inspect @@ -12,38 +13,50 @@ from .convert import encode_engine_value, make_engine_value_decoder from . import _engine # type: ignore + class OpCategory(Enum): """The category of the operation.""" + FUNCTION = "function" SOURCE = "source" STORAGE = "storage" DECLARATION = "declaration" + + @dataclass_transform() class SpecMeta(type): """Meta class for spec classes.""" + def __new__(mcs, name, bases, attrs, category: OpCategory | None = None): cls: type = super().__new__(mcs, name, bases, attrs) if category is not None: # It's the base class. - setattr(cls, '_op_category', category) + setattr(cls, "_op_category", category) else: # It's the specific class providing specific fields. cls = dataclasses.dataclass(cls) return cls -class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: disable=too-few-public-methods + +class SourceSpec(metaclass=SpecMeta, category=OpCategory.SOURCE): # pylint: disable=too-few-public-methods """A source spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)""" -class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint: disable=too-few-public-methods + +class FunctionSpec(metaclass=SpecMeta, category=OpCategory.FUNCTION): # pylint: disable=too-few-public-methods """A function spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)""" -class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: disable=too-few-public-methods + +class StorageSpec(metaclass=SpecMeta, category=OpCategory.STORAGE): # pylint: disable=too-few-public-methods """A storage spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)""" -class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods + +class DeclarationSpec(metaclass=SpecMeta, category=OpCategory.DECLARATION): # pylint: disable=too-few-public-methods """A declaration spec. All its subclass can be instantiated similar to a dataclass, i.e. ClassName(field1=value1, field2=value2, ...)""" + + class Executor(Protocol): """An executor for an operation.""" + op_category: OpCategory @@ -64,6 +77,7 @@ def __call__(self, spec: dict[str, Any], *args, **kwargs): _gpu_dispatch_lock = asyncio.Lock() + @dataclasses.dataclass class OpArgs: """ @@ -72,26 +86,30 @@ class OpArgs: - behavior_version: The behavior version of the executor. Cache will be invalidated if it changes. Must be provided if `cache` is True. """ + gpu: bool = False cache: bool = False behavior_version: int | None = None + def _to_async_call(call: Callable) -> Callable[..., Awaitable[Any]]: if inspect.iscoroutinefunction(call): return call return lambda *args, **kwargs: asyncio.to_thread(lambda: call(*args, **kwargs)) + def _register_op_factory( - category: OpCategory, - expected_args: list[tuple[str, inspect.Parameter]], - expected_return, - executor_cls: type, - spec_cls: type, - op_args: OpArgs, - ): + category: OpCategory, + expected_args: list[tuple[str, inspect.Parameter]], + expected_return, + executor_cls: type, + spec_cls: type, + op_args: OpArgs, +): """ Register an op factory. 
""" + class _Fallback: def enable_cache(self): return op_args.cache @@ -122,15 +140,21 @@ def analyze(self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema): for arg in args: if next_param_idx >= len(expected_args): raise ValueError( - f"Too many arguments passed in: {len(args)} > {len(expected_args)}") + f"Too many arguments passed in: {len(args)} > {len(expected_args)}" + ) arg_name, arg_param = expected_args[next_param_idx] if arg_param.kind in ( - inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.VAR_KEYWORD): + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.VAR_KEYWORD, + ): raise ValueError( - f"Too many positional arguments passed in: {len(args)} > {next_param_idx}") + f"Too many positional arguments passed in: {len(args)} > {next_param_idx}" + ) self._args_decoders.append( make_engine_value_decoder( - [arg_name], arg.value_type['type'], arg_param.annotation)) + [arg_name], arg.value_type["type"], arg_param.annotation + ) + ) if arg_param.kind != inspect.Parameter.VAR_POSITIONAL: next_param_idx += 1 @@ -138,27 +162,50 @@ def analyze(self, *args: _engine.OpArgSchema, **kwargs: _engine.OpArgSchema): for kwarg_name, kwarg in kwargs.items(): expected_arg = next( - (arg for arg in expected_kwargs - if (arg[0] == kwarg_name and arg[1].kind in ( - inspect.Parameter.KEYWORD_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)) - or arg[1].kind == inspect.Parameter.VAR_KEYWORD), - None) + ( + arg + for arg in expected_kwargs + if ( + arg[0] == kwarg_name + and arg[1].kind + in ( + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + ) + or arg[1].kind == inspect.Parameter.VAR_KEYWORD + ), + None, + ) if expected_arg is None: - raise ValueError(f"Unexpected keyword argument passed in: {kwarg_name}") + raise ValueError( + f"Unexpected keyword argument passed in: {kwarg_name}" + ) arg_param = expected_arg[1] self._kwargs_decoders[kwarg_name] = make_engine_value_decoder( - [kwarg_name], kwarg.value_type['type'], arg_param.annotation) - - missing_args = [name for (name, arg) in expected_kwargs - if arg.default is inspect.Parameter.empty - and (arg.kind == inspect.Parameter.POSITIONAL_ONLY or - (arg.kind in (inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_OR_KEYWORD) - and name not in kwargs))] + [kwarg_name], kwarg.value_type["type"], arg_param.annotation + ) + + missing_args = [ + name + for (name, arg) in expected_kwargs + if arg.default is inspect.Parameter.empty + and ( + arg.kind == inspect.Parameter.POSITIONAL_ONLY + or ( + arg.kind + in ( + inspect.Parameter.KEYWORD_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and name not in kwargs + ) + ) + ] if len(missing_args) > 0: raise ValueError(f"Missing arguments: {', '.join(missing_args)}") - prepare_method = getattr(executor_cls, 'analyze', None) + prepare_method = getattr(executor_cls, "analyze", None) if prepare_method is not None: return prepare_method(self, *args, **kwargs) else: @@ -169,14 +216,18 @@ async def prepare(self): Prepare for execution. It's executed after `analyze` and before any `__call__` execution. 
""" - setup_method = getattr(super(), 'prepare', None) + setup_method = getattr(super(), "prepare", None) if setup_method is not None: await _to_async_call(setup_method)() async def __call__(self, *args, **kwargs): - decoded_args = (decoder(arg) for decoder, arg in zip(self._args_decoders, args)) - decoded_kwargs = {arg_name: self._kwargs_decoders[arg_name](arg) - for arg_name, arg in kwargs.items()} + decoded_args = ( + decoder(arg) for decoder, arg in zip(self._args_decoders, args) + ) + decoded_kwargs = { + arg_name: self._kwargs_decoders[arg_name](arg) + for arg_name, arg in kwargs.items() + } if op_args.gpu: # For GPU executions, data-level parallelism is applied, so we don't want to @@ -198,12 +249,14 @@ async def __call__(self, *args, **kwargs): if category == OpCategory.FUNCTION: _engine.register_function_factory( - spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass)) + spec_cls.__name__, _FunctionExecutorFactory(spec_cls, _WrappedClass) + ) else: raise ValueError(f"Unsupported executor type {category}") return _WrappedClass + def executor_class(**args) -> Callable[[type], type]: """ Decorate a class to provide an executor for an op. @@ -216,9 +269,9 @@ def _inner(cls: type[Executor]) -> type: """ # Use `__annotations__` instead of `get_type_hints`, to avoid resolving forward references. type_hints = cls.__annotations__ - if 'spec' not in type_hints: + if "spec" not in type_hints: raise TypeError("Expect a `spec` field with type hint") - spec_cls = resolve_forward_ref(type_hints['spec']) + spec_cls = resolve_forward_ref(type_hints["spec"]) sig = inspect.signature(cls.__call__) return _register_op_factory( category=spec_cls._op_category, @@ -226,10 +279,12 @@ def _inner(cls: type[Executor]) -> type: expected_return=sig.return_annotation, executor_cls=cls, spec_cls=spec_cls, - op_args=op_args) + op_args=op_args, + ) return _inner + def function(**args) -> Callable[[Callable], FunctionSpec]: """ Decorate a function to provide a function for an op. @@ -237,9 +292,8 @@ def function(**args) -> Callable[[Callable], FunctionSpec]: op_args = OpArgs(**args) def _inner(fn: Callable) -> FunctionSpec: - # Convert snake case to camel case. - op_name = ''.join(word.capitalize() for word in fn.__name__.split('_')) + op_name = "".join(word.capitalize() for word in fn.__name__.split("_")) sig = inspect.signature(fn) class _Executor: @@ -261,7 +315,8 @@ def __call__(self, *args, **kwargs): expected_return=sig.return_annotation, executor_cls=_Executor, spec_cls=_Spec, - op_args=op_args) + op_args=op_args, + ) return _Spec() diff --git a/python/cocoindex/query.py b/python/cocoindex/query.py index 73624b790..b3ba240d3 100644 --- a/python/cocoindex/query.py +++ b/python/cocoindex/query.py @@ -9,20 +9,24 @@ _handlers_lock = Lock() _handlers: dict[str, _engine.SimpleSemanticsQueryHandler] = {} + @dataclass class SimpleSemanticsQueryInfo: """ Additional information about the query. """ + similarity_metric: index.VectorSimilarityMetric query_vector: list[float] vector_field_name: str + @dataclass class QueryResult: """ A single result from the query. """ + data: dict[str, Any] score: float @@ -31,6 +35,7 @@ class SimpleSemanticsQueryHandler: """ A query handler that uses simple semantics to query the index. 
""" + _lazy_query_handler: Callable[[], _engine.SimpleSemanticsQueryHandler] def __init__( @@ -39,21 +44,27 @@ def __init__( flow: fl.Flow, target_name: str, query_transform_flow: Callable[..., fl.DataSlice], - default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY) -> None: - + default_similarity_metric: index.VectorSimilarityMetric = index.VectorSimilarityMetric.COSINE_SIMILARITY, + ) -> None: engine_handler = None lock = Lock() + def _lazy_handler() -> _engine.SimpleSemanticsQueryHandler: nonlocal engine_handler, lock if engine_handler is None: with lock: if engine_handler is None: engine_handler = _engine.SimpleSemanticsQueryHandler( - flow.internal_flow(), target_name, - fl.TransformFlow(query_transform_flow, [str]).internal_flow(), - default_similarity_metric.value) + flow.internal_flow(), + target_name, + fl.TransformFlow( + query_transform_flow, [str] + ).internal_flow(), + default_similarity_metric.value, + ) engine_handler.register_query_handler(name) return engine_handler + self._lazy_query_handler = _lazy_handler with _handlers_lock: @@ -65,24 +76,36 @@ def internal_handler(self) -> _engine.SimpleSemanticsQueryHandler: """ return self._lazy_query_handler() - def search(self, query: str, limit: int, vector_field_name: str | None = None, - similarity_metric: index.VectorSimilarityMetric | None = None - ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]: + def search( + self, + query: str, + limit: int, + vector_field_name: str | None = None, + similarity_metric: index.VectorSimilarityMetric | None = None, + ) -> tuple[list[QueryResult], SimpleSemanticsQueryInfo]: """ Search the index with the given query, limit, vector field name, and similarity metric. """ internal_results, internal_info = self.internal_handler().search( - query, limit, vector_field_name, - similarity_metric.value if similarity_metric is not None else None) - results = [QueryResult(data=result['data'], score=result['score']) - for result in internal_results] + query, + limit, + vector_field_name, + similarity_metric.value if similarity_metric is not None else None, + ) + results = [ + QueryResult(data=result["data"], score=result["score"]) + for result in internal_results + ] info = SimpleSemanticsQueryInfo( - similarity_metric=index.VectorSimilarityMetric(internal_info['similarity_metric']), - query_vector=internal_info['query_vector'], - vector_field_name=internal_info['vector_field_name'] + similarity_metric=index.VectorSimilarityMetric( + internal_info["similarity_metric"] + ), + query_vector=internal_info["query_vector"], + vector_field_name=internal_info["vector_field_name"], ) return results, info + def ensure_all_handlers_built() -> None: """ Ensure all handlers are built. 
diff --git a/python/cocoindex/runtime.py b/python/cocoindex/runtime.py index 9f19f2d8f..baeb54d62 100644 --- a/python/cocoindex/runtime.py +++ b/python/cocoindex/runtime.py @@ -6,6 +6,8 @@ import threading import asyncio from typing import Coroutine + + class _ExecutionContext: _lock: threading.Lock _event_loop: asyncio.AbstractEventLoop | None = None @@ -19,11 +21,14 @@ def event_loop(self) -> asyncio.AbstractEventLoop: with self._lock: if self._event_loop is None: self._event_loop = asyncio.new_event_loop() - threading.Thread(target=self._event_loop.run_forever, daemon=True).start() + threading.Thread( + target=self._event_loop.run_forever, daemon=True + ).start() return self._event_loop def run(self, coro: Coroutine): """Run a coroutine in the event loop, blocking until it finishes. Return its result.""" return asyncio.run_coroutine_threadsafe(coro, self.event_loop).result() + execution_context = _ExecutionContext() diff --git a/python/cocoindex/setting.py b/python/cocoindex/setting.py index 4f382b0b5..ed39e25f8 100644 --- a/python/cocoindex/setting.py +++ b/python/cocoindex/setting.py @@ -1,43 +1,55 @@ """ Data types for settings of the cocoindex library. """ + import os from typing import Callable, Self, Any, overload from dataclasses import dataclass -_app_namespace: str = '' +_app_namespace: str = "" + def get_app_namespace(*, trailing_delimiter: str | None = None) -> str: """Get the application namespace. Append the `trailing_delimiter` if not empty.""" - if _app_namespace == '' or trailing_delimiter is None: + if _app_namespace == "" or trailing_delimiter is None: return _app_namespace - return f'{_app_namespace}{trailing_delimiter}' + return f"{_app_namespace}{trailing_delimiter}" + def split_app_namespace(full_name: str, delimiter: str) -> tuple[str, str]: """Split the full name into the application namespace and the rest.""" parts = full_name.split(delimiter, 1) if len(parts) == 1: - return '', parts[0] + return "", parts[0] return (parts[0], parts[1]) + def set_app_namespace(app_namespace: str): """Set the application namespace.""" global _app_namespace # pylint: disable=global-statement _app_namespace = app_namespace + @dataclass class DatabaseConnectionSpec: """ Connection spec for relational database. Used by both internal and target storage. 
""" + url: str user: str | None = None password: str | None = None -def _load_field(target: dict[str, Any], name: str, env_name: str, required: bool = False, - parse: Callable[[str], Any] | None = None): + +def _load_field( + target: dict[str, Any], + name: str, + env_name: str, + required: bool = False, + parse: Callable[[str], Any] | None = None, +): value = os.getenv(env_name) if value is None: if required: @@ -45,9 +57,11 @@ def _load_field(target: dict[str, Any], name: str, env_name: str, required: bool else: target[name] = value if parse is None else parse(value) + @dataclass class Settings: """Settings for the cocoindex library.""" + database: DatabaseConnectionSpec app_namespace: str = "" @@ -61,10 +75,11 @@ def from_env(cls) -> Self: _load_field(db_kwargs, "password", "COCOINDEX_DATABASE_PASSWORD") database = DatabaseConnectionSpec(**db_kwargs) - app_namespace = os.getenv("COCOINDEX_APP_NAMESPACE", '') + app_namespace = os.getenv("COCOINDEX_APP_NAMESPACE", "") return cls(database=database, app_namespace=app_namespace) + @dataclass class ServerSettings: """Settings for the cocoindex server.""" @@ -80,8 +95,12 @@ def from_env(cls) -> Self: """Load settings from environment variables.""" kwargs: dict[str, Any] = dict() _load_field(kwargs, "address", "COCOINDEX_SERVER_ADDRESS") - _load_field(kwargs, "cors_origins", "COCOINDEX_SERVER_CORS_ORIGINS", - parse=ServerSettings.parse_cors_origins) + _load_field( + kwargs, + "cors_origins", + "COCOINDEX_SERVER_CORS_ORIGINS", + parse=ServerSettings.parse_cors_origins, + ) return cls(**kwargs) @overload @@ -97,4 +116,8 @@ def parse_cors_origins(s): """ Parse the CORS origins from a string. """ - return [o for e in s.split(",") if (o := e.strip()) != ""] if s is not None else None \ No newline at end of file + return ( + [o for e in s.split(",") if (o := e.strip()) != ""] + if s is not None + else None + ) diff --git a/python/cocoindex/setup.py b/python/cocoindex/setup.py index fe019e912..daaf885a3 100644 --- a/python/cocoindex/setup.py +++ b/python/cocoindex/setup.py @@ -2,21 +2,25 @@ from . import setting from . import _engine # type: ignore + def sync_setup() -> _engine.SetupStatus: flow.ensure_all_flows_built() return _engine.sync_setup() + def drop_setup(flow_names: list[str]) -> _engine.SetupStatus: flow.ensure_all_flows_built() return _engine.drop_setup([flow.get_full_flow_name(name) for name in flow_names]) + def flow_names_with_setup() -> list[str]: result = [] for name in _engine.flow_names_with_setup(): - app_namespace, name = setting.split_app_namespace(name, '.') + app_namespace, name = setting.split_app_namespace(name, ".") if app_namespace == setting.get_app_namespace(): result.append(name) return result + def apply_setup_changes(setup_status: _engine.SetupStatus): _engine.apply_setup_changes(setup_status) diff --git a/python/cocoindex/sources.py b/python/cocoindex/sources.py index 8356ba2e8..d18f9934d 100644 --- a/python/cocoindex/sources.py +++ b/python/cocoindex/sources.py @@ -1,7 +1,9 @@ """All builtin sources.""" + from . 
import op import datetime + class LocalFile(op.SourceSpec): """Import data from local file system.""" @@ -40,4 +42,4 @@ class AmazonS3(op.SourceSpec): binary: bool = False included_patterns: list[str] | None = None excluded_patterns: list[str] | None = None - sqs_queue_url: str | None = None \ No newline at end of file + sqs_queue_url: str | None = None diff --git a/python/cocoindex/storages.py b/python/cocoindex/storages.py index c9d96c082..9d019fe8c 100644 --- a/python/cocoindex/storages.py +++ b/python/cocoindex/storages.py @@ -1,4 +1,5 @@ """All builtin storages.""" + from dataclasses import dataclass from typing import Sequence @@ -7,11 +8,14 @@ from .auth_registry import AuthEntryReference from .setting import DatabaseConnectionSpec + class Postgres(op.StorageSpec): """Storage powered by Postgres and pgvector.""" + database: AuthEntryReference[DatabaseConnectionSpec] | None = None table_name: str | None = None + @dataclass class Qdrant(op.StorageSpec): """Storage powered by Qdrant - https://qdrant.tech/.""" @@ -20,61 +24,77 @@ class Qdrant(op.StorageSpec): grpc_url: str = "http://localhost:6334/" api_key: str | None = None + @dataclass class Neo4jConnection: """Connection spec for Neo4j.""" + uri: str user: str password: str db: str | None = None + @dataclass class TargetFieldMapping: """Mapping for a graph element (node or relationship) field.""" + source: str # Field name for the node in the Knowledge Graph. # If unspecified, it's the same as `field_name`. target: str | None = None + @dataclass class NodeFromFields: """Spec for a referenced graph node, usually as part of a relationship.""" + label: str fields: list[TargetFieldMapping] + @dataclass class ReferencedNode: """Storage spec for a graph node.""" + label: str primary_key_fields: Sequence[str] vector_indexes: Sequence[index.VectorIndexDef] = () + @dataclass class Nodes: """Spec to map a row to a graph node.""" + kind = "Node" label: str + @dataclass class Relationships: """Spec to map a row to a graph relationship.""" + kind = "Relationship" rel_type: str source: NodeFromFields target: NodeFromFields + # For backwards compatibility only NodeMapping = Nodes RelationshipMapping = Relationships NodeReferenceMapping = NodeFromFields + class Neo4j(op.StorageSpec): """Graph storage powered by Neo4j.""" + connection: AuthEntryReference[Neo4jConnection] mapping: Nodes | Relationships + class Neo4jDeclaration(op.DeclarationSpec): """Declarations for Neo4j.""" diff --git a/python/cocoindex/tests/test_convert.py b/python/cocoindex/tests/test_convert.py index 33c55ff9e..f69a5d22b 100644 --- a/python/cocoindex/tests/test_convert.py +++ b/python/cocoindex/tests/test_convert.py @@ -7,6 +7,7 @@ from cocoindex.typing import encode_enriched_type from cocoindex.convert import encode_engine_value, make_engine_value_decoder + @dataclass class Order: order_id: str @@ -14,37 +15,44 @@ class Order: price: float extra_field: str = "default_extra" + @dataclass class Tag: name: str + @dataclass class Basket: items: list + @dataclass class Customer: name: str order: Order tags: list[Tag] | None = None + @dataclass class NestedStruct: customer: Customer orders: list[Order] count: int = 0 + class OrderNamedTuple(NamedTuple): order_id: str name: str price: float extra_field: str = "default_extra" + class CustomerNamedTuple(NamedTuple): name: str order: OrderNamedTuple tags: list[Tag] | None = None + def build_engine_value_decoder(engine_type_in_py, python_type=None): """ Helper to build a converter for the given engine-side type (as represented in 
Python). @@ -53,16 +61,19 @@ def build_engine_value_decoder(engine_type_in_py, python_type=None): engine_type = encode_enriched_type(engine_type_in_py)["type"] return make_engine_value_decoder([], engine_type, python_type or engine_type_in_py) + def test_encode_engine_value_basic_types(): assert encode_engine_value(123) == 123 assert encode_engine_value(3.14) == 3.14 assert encode_engine_value("hello") == "hello" assert encode_engine_value(True) is True + def test_encode_engine_value_uuid(): u = uuid.uuid4() assert encode_engine_value(u) == u.bytes + def test_encode_engine_value_date_time_types(): d = datetime.date(2024, 1, 1) assert encode_engine_value(d) == d @@ -71,35 +82,65 @@ def test_encode_engine_value_date_time_types(): dt = datetime.datetime(2024, 1, 1, 12, 30) assert encode_engine_value(dt) == dt + def test_encode_engine_value_struct(): order = Order(order_id="O123", name="mixed nuts", price=25.0) assert encode_engine_value(order) == ["O123", "mixed nuts", 25.0, "default_extra"] - + order_nt = OrderNamedTuple(order_id="O123", name="mixed nuts", price=25.0) - assert encode_engine_value(order_nt) == ["O123", "mixed nuts", 25.0, "default_extra"] + assert encode_engine_value(order_nt) == [ + "O123", + "mixed nuts", + 25.0, + "default_extra", + ] + def test_encode_engine_value_list_of_structs(): orders = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)] - assert encode_engine_value(orders) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] - - orders_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)] - assert encode_engine_value(orders_nt) == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] + assert encode_engine_value(orders) == [ + ["O1", "item1", 10.0, "default_extra"], + ["O2", "item2", 20.0, "default_extra"], + ] + + orders_nt = [ + OrderNamedTuple("O1", "item1", 10.0), + OrderNamedTuple("O2", "item2", 20.0), + ] + assert encode_engine_value(orders_nt) == [ + ["O1", "item1", 10.0, "default_extra"], + ["O2", "item2", 20.0, "default_extra"], + ] + def test_encode_engine_value_struct_with_list(): basket = Basket(items=["apple", "banana"]) assert encode_engine_value(basket) == [["apple", "banana"]] + def test_encode_engine_value_nested_struct(): customer = Customer(name="Alice", order=Order("O1", "item1", 10.0)) - assert encode_engine_value(customer) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None] - - customer_nt = CustomerNamedTuple(name="Alice", order=OrderNamedTuple("O1", "item1", 10.0)) - assert encode_engine_value(customer_nt) == ["Alice", ["O1", "item1", 10.0, "default_extra"], None] + assert encode_engine_value(customer) == [ + "Alice", + ["O1", "item1", 10.0, "default_extra"], + None, + ] + + customer_nt = CustomerNamedTuple( + name="Alice", order=OrderNamedTuple("O1", "item1", 10.0) + ) + assert encode_engine_value(customer_nt) == [ + "Alice", + ["O1", "item1", 10.0, "default_extra"], + None, + ] + def test_encode_engine_value_empty_list(): assert encode_engine_value([]) == [] assert encode_engine_value([[]]) == [[]] + def test_encode_engine_value_tuple(): assert encode_engine_value(()) == [] assert encode_engine_value((1, 2, 3)) == [1, 2, 3] @@ -107,9 +148,11 @@ def test_encode_engine_value_tuple(): assert encode_engine_value(([],)) == [[]] assert encode_engine_value(((),)) == [[]] + def test_encode_engine_value_none(): assert encode_engine_value(None) is None + def test_make_engine_value_decoder_basic_types(): for engine_type_in_py, value in [ (int, 42), @@ 
-121,80 +164,190 @@ def test_make_engine_value_decoder_basic_types(): decoder = build_engine_value_decoder(engine_type_in_py) assert decoder(value) == value + @pytest.mark.parametrize( "data_type, engine_val, expected", [ # All fields match (dataclass) - (Order, ["O123", "mixed nuts", 25.0, "default_extra"], Order("O123", "mixed nuts", 25.0, "default_extra")), + ( + Order, + ["O123", "mixed nuts", 25.0, "default_extra"], + Order("O123", "mixed nuts", 25.0, "default_extra"), + ), # All fields match (NamedTuple) - (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), + ( + OrderNamedTuple, + ["O123", "mixed nuts", 25.0, "default_extra"], + OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"), + ), # Extra field in engine value (should ignore extra) - (Order, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], Order("O123", "mixed nuts", 25.0, "default_extra")), - (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), + ( + Order, + ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], + Order("O123", "mixed nuts", 25.0, "default_extra"), + ), + ( + OrderNamedTuple, + ["O123", "mixed nuts", 25.0, "default_extra", "unexpected"], + OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"), + ), # Fewer fields in engine value (should fill with default) - (Order, ["O123", "mixed nuts", 0.0, "default_extra"], Order("O123", "mixed nuts", 0.0, "default_extra")), - (OrderNamedTuple, ["O123", "mixed nuts", 0.0, "default_extra"], OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra")), + ( + Order, + ["O123", "mixed nuts", 0.0, "default_extra"], + Order("O123", "mixed nuts", 0.0, "default_extra"), + ), + ( + OrderNamedTuple, + ["O123", "mixed nuts", 0.0, "default_extra"], + OrderNamedTuple("O123", "mixed nuts", 0.0, "default_extra"), + ), # More fields in engine value (should ignore extra) - (Order, ["O123", "mixed nuts", 25.0, "unexpected"], Order("O123", "mixed nuts", 25.0, "unexpected")), - (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "unexpected"], OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected")), + ( + Order, + ["O123", "mixed nuts", 25.0, "unexpected"], + Order("O123", "mixed nuts", 25.0, "unexpected"), + ), + ( + OrderNamedTuple, + ["O123", "mixed nuts", 25.0, "unexpected"], + OrderNamedTuple("O123", "mixed nuts", 25.0, "unexpected"), + ), # Truly extra field (should ignore the fifth field) - (Order, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], Order("O123", "mixed nuts", 25.0, "default_extra")), - (OrderNamedTuple, ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), + ( + Order, + ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], + Order("O123", "mixed nuts", 25.0, "default_extra"), + ), + ( + OrderNamedTuple, + ["O123", "mixed nuts", 25.0, "default_extra", "ignored"], + OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"), + ), # Missing optional field in engine value (tags=None) - (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)), - (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], None], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)), + ( + Customer, + ["Alice", ["O1", "item1", 10.0, "default_extra"], None], + Customer("Alice", Order("O1", 
"item1", 10.0, "default_extra"), None), + ), + ( + CustomerNamedTuple, + ["Alice", ["O1", "item1", 10.0, "default_extra"], None], + CustomerNamedTuple( + "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None + ), + ), # Extra field in engine value for Customer (should ignore) - (Customer, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")])), - (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip")])), + ( + Customer, + ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], + Customer( + "Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")] + ), + ), + ( + CustomerNamedTuple, + ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]], "extra"], + CustomerNamedTuple( + "Alice", + OrderNamedTuple("O1", "item1", 10.0, "default_extra"), + [Tag("vip")], + ), + ), # Missing optional field with default - (Order, ["O123", "mixed nuts", 25.0], Order("O123", "mixed nuts", 25.0, "default_extra")), - (OrderNamedTuple, ["O123", "mixed nuts", 25.0], OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra")), + ( + Order, + ["O123", "mixed nuts", 25.0], + Order("O123", "mixed nuts", 25.0, "default_extra"), + ), + ( + OrderNamedTuple, + ["O123", "mixed nuts", 25.0], + OrderNamedTuple("O123", "mixed nuts", 25.0, "default_extra"), + ), # Partial optional fields - (Customer, ["Alice", ["O1", "item1", 10.0]], Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None)), - (CustomerNamedTuple, ["Alice", ["O1", "item1", 10.0]], CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None)), - ] + ( + Customer, + ["Alice", ["O1", "item1", 10.0]], + Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), None), + ), + ( + CustomerNamedTuple, + ["Alice", ["O1", "item1", 10.0]], + CustomerNamedTuple( + "Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), None + ), + ), + ], ) def test_struct_decoder_cases(data_type, engine_val, expected): decoder = build_engine_value_decoder(data_type) assert decoder(engine_val) == expected + def test_make_engine_value_decoder_collections(): # List of structs (dataclass) decoder = build_engine_value_decoder(list[Order]) engine_val = [ ["O1", "item1", 10.0, "default_extra"], - ["O2", "item2", 20.0, "default_extra"] + ["O2", "item2", 20.0, "default_extra"], ] - assert decoder(engine_val) == [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")] - + assert decoder(engine_val) == [ + Order("O1", "item1", 10.0, "default_extra"), + Order("O2", "item2", 20.0, "default_extra"), + ] + # List of structs (NamedTuple) decoder = build_engine_value_decoder(list[OrderNamedTuple]) - assert decoder(engine_val) == [OrderNamedTuple("O1", "item1", 10.0, "default_extra"), OrderNamedTuple("O2", "item2", 20.0, "default_extra")] - + assert decoder(engine_val) == [ + OrderNamedTuple("O1", "item1", 10.0, "default_extra"), + OrderNamedTuple("O2", "item2", 20.0, "default_extra"), + ] + # Struct with list field decoder = build_engine_value_decoder(Customer) - engine_val = ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"], ["premium"]]] - assert decoder(engine_val) == Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")]) - + engine_val = [ + "Alice", + ["O1", "item1", 10.0, "default_extra"], + [["vip"], 
["premium"]], + ] + assert decoder(engine_val) == Customer( + "Alice", + Order("O1", "item1", 10.0, "default_extra"), + [Tag("vip"), Tag("premium")], + ) + # NamedTuple with list field decoder = build_engine_value_decoder(CustomerNamedTuple) - assert decoder(engine_val) == CustomerNamedTuple("Alice", OrderNamedTuple("O1", "item1", 10.0, "default_extra"), [Tag("vip"), Tag("premium")]) - + assert decoder(engine_val) == CustomerNamedTuple( + "Alice", + OrderNamedTuple("O1", "item1", 10.0, "default_extra"), + [Tag("vip"), Tag("premium")], + ) + # Struct with struct field decoder = build_engine_value_decoder(NestedStruct) engine_val = [ ["Alice", ["O1", "item1", 10.0, "default_extra"], [["vip"]]], - [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]], - 2 + [ + ["O1", "item1", 10.0, "default_extra"], + ["O2", "item2", 20.0, "default_extra"], + ], + 2, ] assert decoder(engine_val) == NestedStruct( Customer("Alice", Order("O1", "item1", 10.0, "default_extra"), [Tag("vip")]), - [Order("O1", "item1", 10.0, "default_extra"), Order("O2", "item2", 20.0, "default_extra")], - 2 + [ + Order("O1", "item1", 10.0, "default_extra"), + Order("O2", "item2", 20.0, "default_extra"), + ], + 2, ) + def make_engine_order(fields): - return make_dataclass('EngineOrder', fields) + return make_dataclass("EngineOrder", fields) + def make_python_order(fields, defaults=None): if defaults is None: @@ -205,7 +358,8 @@ def make_python_order(fields, defaults=None): ordered_fields = non_default_fields + default_fields # Prepare the namespace for defaults (only for fields at the end) namespace = {k: defaults[k] for k, _ in default_fields} - return make_dataclass('PythonOrder', ordered_fields, namespace=namespace) + return make_dataclass("PythonOrder", ordered_fields, namespace=namespace) + @pytest.mark.parametrize( "engine_fields, python_fields, python_defaults, engine_val, expected_python_val", @@ -266,9 +420,11 @@ def make_python_order(fields, defaults=None): ["O123", "mixed nuts", 25.0], ("O123", "mixed nuts"), ), - ] + ], ) -def test_field_position_cases(engine_fields, python_fields, python_defaults, engine_val, expected_python_val): +def test_field_position_cases( + engine_fields, python_fields, python_defaults, engine_val, expected_python_val +): EngineOrder = make_engine_order(engine_fields) PythonOrder = make_python_order(python_fields, python_defaults) decoder = build_engine_value_decoder(EngineOrder, PythonOrder) @@ -277,36 +433,57 @@ def test_field_position_cases(engine_fields, python_fields, python_defaults, eng # Instantiate using keyword arguments (order doesn't matter) assert decoder(engine_val) == PythonOrder(**expected_dict) + def test_roundtrip_ltable(): t = list[Order] value = [Order("O1", "item1", 10.0), Order("O2", "item2", 20.0)] encoded = encode_engine_value(value) - assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] + assert encoded == [ + ["O1", "item1", 10.0, "default_extra"], + ["O2", "item2", 20.0, "default_extra"], + ] decoded = build_engine_value_decoder(t)(encoded) assert decoded == value - + t_nt = list[OrderNamedTuple] - value_nt = [OrderNamedTuple("O1", "item1", 10.0), OrderNamedTuple("O2", "item2", 20.0)] + value_nt = [ + OrderNamedTuple("O1", "item1", 10.0), + OrderNamedTuple("O2", "item2", 20.0), + ] encoded = encode_engine_value(value_nt) - assert encoded == [["O1", "item1", 10.0, "default_extra"], ["O2", "item2", 20.0, "default_extra"]] + assert encoded == [ + ["O1", "item1", 10.0, "default_extra"], + ["O2", 
"item2", 20.0, "default_extra"], + ] decoded = build_engine_value_decoder(t_nt)(encoded) assert decoded == value_nt + def test_roundtrip_ktable_str_key(): t = dict[str, Order] value = {"K1": Order("O1", "item1", 10.0), "K2": Order("O2", "item2", 20.0)} encoded = encode_engine_value(value) - assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]] + assert encoded == [ + ["K1", "O1", "item1", 10.0, "default_extra"], + ["K2", "O2", "item2", 20.0, "default_extra"], + ] decoded = build_engine_value_decoder(t)(encoded) assert decoded == value - + t_nt = dict[str, OrderNamedTuple] - value_nt = {"K1": OrderNamedTuple("O1", "item1", 10.0), "K2": OrderNamedTuple("O2", "item2", 20.0)} + value_nt = { + "K1": OrderNamedTuple("O1", "item1", 10.0), + "K2": OrderNamedTuple("O2", "item2", 20.0), + } encoded = encode_engine_value(value_nt) - assert encoded == [["K1", "O1", "item1", 10.0, "default_extra"], ["K2", "O2", "item2", 20.0, "default_extra"]] + assert encoded == [ + ["K1", "O1", "item1", 10.0, "default_extra"], + ["K2", "O2", "item2", 20.0, "default_extra"], + ] decoded = build_engine_value_decoder(t_nt)(encoded) assert decoded == value_nt + def test_roundtrip_ktable_struct_key(): @dataclass(frozen=True) class OrderKey: @@ -314,22 +491,35 @@ class OrderKey: version: int t = dict[OrderKey, Order] - value = {OrderKey("A", 3): Order("O1", "item1", 10.0), OrderKey("B", 4): Order("O2", "item2", 20.0)} + value = { + OrderKey("A", 3): Order("O1", "item1", 10.0), + OrderKey("B", 4): Order("O2", "item2", 20.0), + } encoded = encode_engine_value(value) - assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"], - [["B", 4], "O2", "item2", 20.0, "default_extra"]] + assert encoded == [ + [["A", 3], "O1", "item1", 10.0, "default_extra"], + [["B", 4], "O2", "item2", 20.0, "default_extra"], + ] decoded = build_engine_value_decoder(t)(encoded) assert decoded == value - + t_nt = dict[OrderKey, OrderNamedTuple] - value_nt = {OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0), OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0)} + value_nt = { + OrderKey("A", 3): OrderNamedTuple("O1", "item1", 10.0), + OrderKey("B", 4): OrderNamedTuple("O2", "item2", 20.0), + } encoded = encode_engine_value(value_nt) - assert encoded == [[["A", 3], "O1", "item1", 10.0, "default_extra"], - [["B", 4], "O2", "item2", 20.0, "default_extra"]] + assert encoded == [ + [["A", 3], "O1", "item1", 10.0, "default_extra"], + [["B", 4], "O2", "item2", 20.0, "default_extra"], + ] decoded = build_engine_value_decoder(t_nt)(encoded) assert decoded == value_nt + IntVectorType = cocoindex.Vector[int, Literal[5]] + + def test_vector_as_vector() -> None: value: IntVectorType = [1, 2, 3, 4, 5] encoded = encode_engine_value(value) @@ -337,11 +527,13 @@ def test_vector_as_vector() -> None: decoded = build_engine_value_decoder(IntVectorType)(encoded) assert decoded == value + ListIntType = list[int] + + def test_vector_as_list() -> None: value: ListIntType = [1, 2, 3, 4, 5] encoded = encode_engine_value(value) assert encoded == [1, 2, 3, 4, 5] decoded = build_engine_value_decoder(ListIntType)(encoded) assert decoded == value - diff --git a/python/cocoindex/typing.py b/python/cocoindex/typing.py index d1dd6ccdb..479a34081 100644 --- a/python/cocoindex/typing.py +++ b/python/cocoindex/typing.py @@ -5,14 +5,28 @@ import types import inspect import uuid -from typing import Annotated, NamedTuple, Any, TypeVar, TYPE_CHECKING, overload, Sequence, Generic, Literal, Protocol +from typing import ( 
+ Annotated, + NamedTuple, + Any, + TypeVar, + TYPE_CHECKING, + overload, + Sequence, + Generic, + Literal, + Protocol, +) + class VectorInfo(NamedTuple): dim: int | None + class TypeKind(NamedTuple): kind: str + class TypeAttr: key: str value: Any @@ -21,27 +35,31 @@ def __init__(self, key: str, value: Any): self.key = key self.value = value + Annotation = TypeKind | TypeAttr | VectorInfo -Float32 = Annotated[float, TypeKind('Float32')] -Float64 = Annotated[float, TypeKind('Float64')] -Range = Annotated[tuple[int, int], TypeKind('Range')] -Json = Annotated[Any, TypeKind('Json')] -LocalDateTime = Annotated[datetime.datetime, TypeKind('LocalDateTime')] -OffsetDateTime = Annotated[datetime.datetime, TypeKind('OffsetDateTime')] +Float32 = Annotated[float, TypeKind("Float32")] +Float64 = Annotated[float, TypeKind("Float64")] +Range = Annotated[tuple[int, int], TypeKind("Range")] +Json = Annotated[Any, TypeKind("Json")] +LocalDateTime = Annotated[datetime.datetime, TypeKind("LocalDateTime")] +OffsetDateTime = Annotated[datetime.datetime, TypeKind("OffsetDateTime")] if TYPE_CHECKING: - T_co = TypeVar('T_co', covariant=True) - Dim_co = TypeVar('Dim_co', bound=int, covariant=True) + T_co = TypeVar("T_co", covariant=True) + Dim_co = TypeVar("Dim_co", bound=int, covariant=True) class Vector(Protocol, Generic[T_co, Dim_co]): """Vector[T, Dim] is a special typing alias for a list[T] with optional dimension info""" + def __getitem__(self, index: int) -> T_co: ... def __len__(self) -> int: ... else: + class Vector: # type: ignore[unreachable] - """ A special typing alias for a list[T] with optional dimension info """ + """A special typing alias for a list[T] with optional dimension info""" + def __class_getitem__(self, params): if not isinstance(params, tuple): # Only element type provided @@ -54,32 +72,40 @@ def __class_getitem__(self, params): dim = typing.get_args(dim)[0] # Extract the literal value return Annotated[list[elem_type], VectorInfo(dim=dim)] -TABLE_TYPES = ('KTable', 'LTable') -KEY_FIELD_NAME = '_key' + +TABLE_TYPES = ("KTable", "LTable") +KEY_FIELD_NAME = "_key" ElementType = type | tuple[type, type] + def is_namedtuple_type(t) -> bool: return isinstance(t, type) and issubclass(t, tuple) and hasattr(t, "_fields") + def _is_struct_type(t) -> bool: - return isinstance(t, type) and (dataclasses.is_dataclass(t) or is_namedtuple_type(t)) + return isinstance(t, type) and ( + dataclasses.is_dataclass(t) or is_namedtuple_type(t) + ) + @dataclasses.dataclass class AnalyzedTypeInfo: """ Analyzed info of a Python type. """ + kind: str vector_info: VectorInfo | None # For Vector - elem_type: ElementType | None # For Vector and Table + elem_type: ElementType | None # For Vector and Table - key_type: type | None # For element of KTable - struct_type: type | None # For Struct, a dataclass or namedtuple + key_type: type | None # For element of KTable + struct_type: type | None # For Struct, a dataclass or namedtuple attrs: dict[str, Any] | None nullable: bool = False + def analyze_type_info(t) -> AnalyzedTypeInfo: """ Analyze a Python type and return the analyzed info. 
@@ -100,10 +126,13 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: t = t.__origin__ elif base_type is types.UnionType: possible_types = typing.get_args(t) - non_none_types = [arg for arg in possible_types if arg not in (None, types.NoneType)] + non_none_types = [ + arg for arg in possible_types if arg not in (None, types.NoneType) + ] if len(non_none_types) != 1: raise ValueError( - f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}") + f"Expect exactly one non-None choice for Union type, but got {len(non_none_types)}: {t}" + ) t = non_none_types[0] if len(possible_types) > 1: nullable = True @@ -130,8 +159,8 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: struct_type = t if kind is None: - kind = 'Struct' - elif kind != 'Struct': + kind = "Struct" + elif kind != "Struct": raise ValueError(f"Unexpected type kind for struct: {kind}") elif base_type is collections.abc.Sequence or base_type is list: args = typing.get_args(t) @@ -139,40 +168,42 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: if kind is None: if _is_struct_type(elem_type): - kind = 'LTable' + kind = "LTable" if vector_info is not None: - raise ValueError("Vector element must be a simple type, not a struct") + raise ValueError( + "Vector element must be a simple type, not a struct" + ) else: - kind = 'Vector' + kind = "Vector" if vector_info is None: vector_info = VectorInfo(dim=None) - elif not (kind == 'Vector' or kind in TABLE_TYPES): + elif not (kind == "Vector" or kind in TABLE_TYPES): raise ValueError(f"Unexpected type kind for list: {kind}") elif base_type is collections.abc.Mapping or base_type is dict: args = typing.get_args(t) elem_type = (args[0], args[1]) - kind = 'KTable' + kind = "KTable" elif kind is None: if t is bytes: - kind = 'Bytes' + kind = "Bytes" elif t is str: - kind = 'Str' + kind = "Str" elif t is bool: - kind = 'Bool' + kind = "Bool" elif t is int: - kind = 'Int64' + kind = "Int64" elif t is float: - kind = 'Float64' + kind = "Float64" elif t is uuid.UUID: - kind = 'Uuid' + kind = "Uuid" elif t is datetime.date: - kind = 'Date' + kind = "Date" elif t is datetime.time: - kind = 'Time' + kind = "Time" elif t is datetime.datetime: - kind = 'OffsetDateTime' + kind = "OffsetDateTime" elif t is datetime.timedelta: - kind = 'TimeDelta' + kind = "TimeDelta" else: raise ValueError(f"type unsupported yet: {t}") @@ -186,8 +217,12 @@ def analyze_type_info(t) -> AnalyzedTypeInfo: nullable=nullable, ) -def _encode_fields_schema(struct_type: type, key_type: type | None = None) -> list[dict[str, Any]]: + +def _encode_fields_schema( + struct_type: type, key_type: type | None = None +) -> list[dict[str, Any]]: result = [] + def add_field(name: str, t) -> None: try: type_info = encode_enriched_type_info(analyze_type_info(t)) @@ -197,7 +232,7 @@ def add_field(name: str, t) -> None: f"{struct_type.__name__}.{name}: {t}" ) raise - type_info['name'] = name + type_info["name"] = name result.append(type_info) if key_type is not None: @@ -212,53 +247,60 @@ def add_field(name: str, t) -> None: return result + def _encode_type(type_info: AnalyzedTypeInfo) -> dict[str, Any]: - encoded_type: dict[str, Any] = { 'kind': type_info.kind } + encoded_type: dict[str, Any] = {"kind": type_info.kind} - if type_info.kind == 'Struct': + if type_info.kind == "Struct": if type_info.struct_type is None: raise ValueError("Struct type must have a dataclass or namedtuple type") - encoded_type['fields'] = _encode_fields_schema(type_info.struct_type, type_info.key_type) + encoded_type["fields"] = 
_encode_fields_schema( + type_info.struct_type, type_info.key_type + ) if doc := inspect.getdoc(type_info.struct_type): - encoded_type['description'] = doc + encoded_type["description"] = doc - elif type_info.kind == 'Vector': + elif type_info.kind == "Vector": if type_info.vector_info is None: raise ValueError("Vector type must have a vector info") if type_info.elem_type is None: raise ValueError("Vector type must have an element type") - encoded_type['element_type'] = _encode_type(analyze_type_info(type_info.elem_type)) - encoded_type['dimension'] = type_info.vector_info.dim + encoded_type["element_type"] = _encode_type( + analyze_type_info(type_info.elem_type) + ) + encoded_type["dimension"] = type_info.vector_info.dim elif type_info.kind in TABLE_TYPES: if type_info.elem_type is None: raise ValueError(f"{type_info.kind} type must have an element type") row_type_info = analyze_type_info(type_info.elem_type) - encoded_type['row'] = _encode_type(row_type_info) + encoded_type["row"] = _encode_type(row_type_info) return encoded_type + def encode_enriched_type_info(enriched_type_info: AnalyzedTypeInfo) -> dict[str, Any]: """ Encode an enriched type info to a CocoIndex engine's type representation """ - encoded: dict[str, Any] = {'type': _encode_type(enriched_type_info)} + encoded: dict[str, Any] = {"type": _encode_type(enriched_type_info)} if enriched_type_info.attrs is not None: - encoded['attrs'] = enriched_type_info.attrs + encoded["attrs"] = enriched_type_info.attrs if enriched_type_info.nullable: - encoded['nullable'] = True + encoded["nullable"] = True return encoded + @overload -def encode_enriched_type(t: None) -> None: - ... +def encode_enriched_type(t: None) -> None: ... + @overload -def encode_enriched_type(t: Any) -> dict[str, Any]: - ... +def encode_enriched_type(t: Any) -> dict[str, Any]: ... + def encode_enriched_type(t) -> dict[str, Any] | None: """ @@ -269,7 +311,8 @@ def encode_enriched_type(t) -> dict[str, Any] | None: return encode_enriched_type_info(analyze_type_info(t)) + def resolve_forward_ref(t): if t is str: - return eval(t) # pylint: disable=eval-used + return eval(t) # pylint: disable=eval-used return t diff --git a/python/cocoindex/utils.py b/python/cocoindex/utils.py index e5be399a0..7a3177980 100644 --- a/python/cocoindex/utils.py +++ b/python/cocoindex/utils.py @@ -1,9 +1,17 @@ from .flow import Flow from .setting import get_app_namespace -def get_target_storage_default_name(flow: Flow, target_name: str, delimiter: str = "__") -> str: + +def get_target_storage_default_name( + flow: Flow, target_name: str, delimiter: str = "__" +) -> str: """ Get the default name for a target. It's used as the underlying storage name (e.g. a table, a collection, etc.) followed by most storage backends, if not explicitly specified. """ - return get_app_namespace(trailing_delimiter=delimiter) + flow.name + delimiter + target_name + return ( + get_app_namespace(trailing_delimiter=delimiter) + + flow.name + + delimiter + + target_name + ) diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 000000000..5bae73023 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,5 @@ +[format] +quote-style = "double" +indent-style = "space" +skip-magic-trailing-comma = false +line-ending = "lf"
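
For reference, the [format] options in the new ruff.toml above match the style applied in the reformatted files earlier in this diff: double quotes, space indentation, LF line endings, and the "magic trailing comma" respected (skip-magic-trailing-comma = false). A hypothetical before/after snippet, not taken from the repository, illustrating the effect:

    # Input written as:  greet('world',)   <- single quotes, trailing comma present.
    # After `ruff format` with the settings above, quotes become double and the
    # trailing comma keeps the call expanded to one argument per line:
    def greet(name: str) -> None:
        print(f"Hello, {name}")

    greet(
        "world",
    )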