diff --git a/apps/crawl-to-rag/Dockerfile b/apps/crawl-to-rag/Dockerfile index 0f3a6e6c..89f24688 100644 --- a/apps/crawl-to-rag/Dockerfile +++ b/apps/crawl-to-rag/Dockerfile @@ -8,8 +8,9 @@ FROM aperturedata/workflows-base ENV APP_NAME=workflows-crawl-to-rag -# Needed for text-embeddings -RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* +# Install dependencies for embeddings +RUN pip install --no-cache-dir -r /app/embeddings/requirements_cpu.txt +RUN pip install --no-cache-dir -r /app/embeddings/requirements.txt # copy in the app directories COPY --from=crawl-website /app /workflows/crawl-website @@ -28,6 +29,7 @@ RUN pip install --no-cache-dir -r /requirements.txt COPY --from=rag /requirements.txt /requirements.txt RUN pip install --no-cache-dir -r /requirements.txt + EXPOSE 8000 COPY app.sh /app/ diff --git a/apps/rag/Dockerfile b/apps/rag/Dockerfile index 47fd5c11..e8f16d7a 100644 --- a/apps/rag/Dockerfile +++ b/apps/rag/Dockerfile @@ -8,6 +8,10 @@ RUN apt-get update && apt-get install -y git && rm -rf /var/lib/apt/lists/* COPY requirements.txt / RUN pip install --no-cache-dir -r /requirements.txt +# Install dependencies for embeddings +RUN pip install --no-cache-dir -r /app/embeddings/requirements_cpu.txt +RUN pip install --no-cache-dir -r /app/embeddings/requirements.txt + # We prefer to cache models in the docker image rather than load them # at run time. COPY app/llm.py /app/llm.py diff --git a/apps/sql-server/Dockerfile b/apps/sql-server/Dockerfile index 1376b089..908018c2 100644 --- a/apps/sql-server/Dockerfile +++ b/apps/sql-server/Dockerfile @@ -5,7 +5,7 @@ ENV APP_NAME=workflows-sql-server ENV POSTGRES_VERSION=17 ARG MULTICORN_VERSION=3.0 -# Add PGDG repository and install PostgreSQL 17 +# Add PGDG repository and install PostgreSQL RUN apt-get update && apt-get install -y wget gnupg lsb-release \ && echo "deb http://apt.postgresql.org/pub/repos/apt $(lsb_release -cs)-pgdg main" \ > /etc/apt/sources.list.d/pgdg.list \ @@ -23,28 +23,39 @@ RUN echo "listen_addresses = '*'" >> /etc/postgresql/${POSTGRES_VERSION}/main/po RUN echo "host all all 0.0.0.0/0 md5" >> /etc/postgresql/17/main/pg_hba.conf # Postgres/Multicorn insists on using the system Python, so we need to disable the virtual environment +# Store current VIRTUAL_ENV and PATH values +ENV OLD_VIRTUAL_ENV="${VIRTUAL_ENV}" +ENV OLD_PATH="${PATH}" +ENV OLD_PYTHONPATH="${PYTHONPATH}" + +# Disable virtual environment ENV VIRTUAL_ENV= ENV PATH="/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin" +ENV PYTHONPATH="" # Install multicorn2 Python module into system Python RUN /usr/bin/python3 -m pip install --no-cache-dir "git+https://github.com/pgsql-io/multicorn2.git" -# Build and install multicorn2 Postgres extension -RUN git clone https://github.com/pgsql-io/multicorn2.git /multicorn2 \ +# Build and install multicorn2 Postgres extension +RUN git clone --single-branch --branch main https://github.com/pgsql-io/multicorn2.git /multicorn2 \ && cd /multicorn2 \ && make PYTHON=/usr/bin/python3 \ && make install -RUN pip install --no-cache-dir aperturedb dotenv +RUN pip install --no-cache-dir aperturedb dotenv numpy pydantic + +# Install dependencies for embeddings +RUN pip install --no-cache-dir -r /app/embeddings/requirements_cpu.txt +RUN pip install --no-cache-dir -r /app/embeddings/requirements.txt # Copy and install our FDW into system Python COPY fdw /fdw RUN cd /fdw && /usr/bin/python3 -m pip install . 
# Restore virtual environment -ENV VIRTUAL_ENV=/opt/venv -ENV PATH="/opt/venv/bin:/opt/venv/lib/python3.10/site-packages:$PATH" -ENV PYTHONPATH="/app:/opt/venv/lib/python3.10/site-packages" +ENV VIRTUAL_ENV=${OLD_VIRTUAL_ENV} +ENV PATH="${OLD_PATH}" +ENV PYTHONPATH="/app:${OLD_PYTHONPATH}" # Install application requirements COPY requirements.txt /requirements.txt diff --git a/apps/sql-server/app/sql/functions.sql b/apps/sql-server/app/sql/functions.sql index 1fedbcdc..b69e9cad 100644 --- a/apps/sql-server/app/sql/functions.sql +++ b/apps/sql-server/app/sql/functions.sql @@ -102,4 +102,43 @@ $$ LANGUAGE SQL IMMUTABLE; CREATE OR REPLACE FUNCTION OPERATIONS(VARIADIC ops jsonb[]) RETURNS jsonb AS $$ SELECT jsonb_agg(op) FROM unnest($1) AS op -$$ LANGUAGE SQL IMMUTABLE; \ No newline at end of file +$$ LANGUAGE SQL IMMUTABLE; + + +-- Find similar + +CREATE OR REPLACE FUNCTION FIND_SIMILAR( + text TEXT DEFAULT NULL, + image BYTEA DEFAULT NULL, + vector JSONB DEFAULT NULL, + k INT DEFAULT 10, + knn_first BOOLEAN DEFAULT TRUE +) RETURNS JSONB AS $$ +DECLARE + mode_count INT; +BEGIN + -- Count how many modes are specified + mode_count := (CASE WHEN text IS NOT NULL THEN 1 ELSE 0 END) + + (CASE WHEN image IS NOT NULL THEN 1 ELSE 0 END) + + (CASE WHEN vector IS NOT NULL THEN 1 ELSE 0 END); + + IF mode_count != 1 THEN + RAISE EXCEPTION 'FIND_SIMILAR requires exactly one of text, image, or vector'; + END IF; + + IF k IS NULL OR k <= 0 THEN + RAISE EXCEPTION 'k must be a positive integer'; + END IF; + + RETURN jsonb_build_object( + 'type', 'find_similar', + 'text', text, + 'image', image, + 'vector', vector, + 'k_neighbors', k, + 'knn_first', knn_first + ); +END; +$$ LANGUAGE plpgsql IMMUTABLE; + +COMMENT ON FUNCTION FIND_SIMILAR IS 'Find similar items based on one of text, image, or vector.'; diff --git a/apps/sql-server/fdw/fdw/__init__.py b/apps/sql-server/fdw/fdw/__init__.py index 8e247bf6..c7874b20 100644 --- a/apps/sql-server/fdw/fdw/__init__.py +++ b/apps/sql-server/fdw/fdw/__init__.py @@ -1,16 +1,20 @@ +from .common import get_pool, get_log_level, compact_pretty_json +from .table import TableOptions +from .column import ColumnOptions +from multicorn import ForeignDataWrapper, Qual +import numpy as np import atexit import json import os import logging -from aperturedb.CommonLibrary import create_connector from datetime import datetime import sys -from multicorn import TableDefinition, ColumnDefinition, ForeignDataWrapper from itertools import zip_longest -from typing import Optional, Set, Tuple, Generator, List, Dict -from dotenv import load_dotenv -from collections import defaultdict -from .common import get_pool, get_log_level, TableOptions, ColumnOptions +from typing import Optional, Set, Tuple, Generator, List, Dict, Any, Iterable + +import pydantic +import sys +import importlib.util # Configure logging @@ -22,7 +26,6 @@ logging.basicConfig(level=log_level, force=True) logger = logging.getLogger(__name__) logger.setLevel(log_level) - logger.addHandler(handler) logger.propagate = False @@ -37,17 +40,6 @@ def flush_logs(): atexit.register(flush_logs) - -# Mapping from ApertureDB types to PostgreSQL types. -TYPE_MAP = { - "number": "double precision", - "string": "text", - "boolean": "boolean", - "datetime": "timestamptz", - "json": "jsonb", - "blob": "bytea", -} - # Queries are processed in batches, but the client doesn't know because result rows are yielded one by one. 
BATCH_SIZE = 100 BATCH_SIZE_WITH_BLOBS = 10 @@ -62,7 +54,51 @@ class FDW(ForeignDataWrapper): It also passes options for each table and column that are passed into `__init__`. """ + @classmethod + def import_schema(cls, schema, srv_options, options, restriction_type, restricts): + """ + Import the schema from ApertureDB and return a list of TableDefinitions. + This method is called when the foreign data wrapper is created. + The result of this is to create the foreign tables in PostgreSQL. + + Note that we cannot add comments, foreign keys, or other constraints here. + + This method is called once per schema. + """ + try: + # Put these here for better error handling + from .system import system_schema + from .entity import entity_schema + from .connection import connection_schema + from .descriptor import descriptor_schema + + logger.info(f"Importing schema {schema} with options: {options}") + if schema == "system": + return system_schema() + elif schema == "entity": + return entity_schema() + elif schema == "connection": + return connection_schema() + elif schema == "descriptor": + return descriptor_schema() + else: + raise ValueError(f"Unknown schema: {schema}") + except: + logger.exception( + f"Error importing schema {schema}: {sys.exc_info()[1]}") + flush_logs() + raise + logger.info(f"Schema {schema} imported successfully") + def __init__(self, fdw_options, fdw_columns): + """ + Initialize the FDW with the given options and columns, + which are generated by the import_schema method. + + Args: + fdw_options (dict): Options for the foreign table. + fdw_columns (dict): Columns for the foreign table. + """ super().__init__(fdw_options, fdw_columns) self._options = TableOptions.from_string(fdw_options) @@ -71,115 +107,179 @@ def __init__(self, fdw_options, fdw_columns): for name, col in fdw_columns.items()} logger.info("FDW initialized with options: %s", fdw_options) - def _normalize_row(self, columns, row: dict) -> dict: - """ - Normalize a row to ensure it has the correct types for PostgreSQL. - This is used to convert ApertureDB types to PostgreSQL types. - """ - result = {} - for col in columns: - if col not in row: - continue - type_ = self._columns[col].type - if type_ == "datetime": - value = row[col]["_date"] if row[col] else None - elif type_ == "json": - value = json.dumps(row[col]) - elif type_ == "blob": - value = row[col] - else: - value = row[col] - result[col] = value - return result + def execute(self, quals: List[Qual], columns: Set[str]) -> Generator[dict, None, None]: + """ Execute the FDW query with the given quals and columns. - def _get_as_format(self, quals) -> Optional[str]: - """ - Get the 'as_format' from the quals if it exists. - This is used to determine how to return image data. + Args: + quals (list): List of conditions to filter the results. + Note that filtering is optional because PostgreSQL will also filter the results. + columns (set): List of columns to return in the results. 
""" - for qual in quals: - if qual.field_name == "_as_format": - assert qual.operator == "=", f"Unexpected operator for _as_format: {qual.operator} Expected '='" - return qual.value - return None - def _get_operations(self, quals) -> Optional[List[dict]]: + start_time = datetime.now() + logger.info( + f"Executing FDW {self._options.table_name} with quals: {quals} and columns: {columns}") + + self._check_quals(quals, columns) + + query = self._get_query(quals, columns) + + query_blobs = self._get_query_blobs(quals, columns) + + n_results = 0 + total_elapsed_time = 0 + exhausted = False + n_queries = 0 + try: + while query: + n_queries += 1 + gen = self._get_query_results(query, query_blobs) + try: + while True: + row, blob = next(gen) + result = self._post_process_row( + quals, columns, row, blob) + + logger.debug( + f"Yielding row: {json.dumps({k: v[:10] if isinstance(v, str) else len(v) if isinstance(v, (bytes, list)) else v for k, v in row.items()}, indent=2)} blob {len(blob) if blob else None}" + ) + n_results += 1 + if n_results % 1000 == 0: + logger.info( + f"Yielded {n_results} results so far for FDW {self._options.table_name}") + yield result + except StopIteration as e: + response, elapsed_time = e.value # return value from _get_query_results + total_elapsed_time += elapsed_time.total_seconds() + + query = self._get_next_query(query, response) + exhausted = True + finally: + elapsed_time = datetime.now() - start_time + logger.info( + f"Executed FDW {self._options.table_name} with {n_results} results and {n_queries} queries in {total_elapsed_time:.2f} seconds in ADB, {elapsed_time.total_seconds():.2f} seconds in execute, {'exhausted' if exhausted else 'not exhausted'}.") + + def _check_quals(self, quals: List[Qual], columns: Set[str]) -> None: """ - Get the 'operations' from the quals if it exists. - This is used to determine what operations to perform on the image data. + Check the quals to ensure they are valid. """ - for qual in quals: - if qual.field_name == "_operations": - assert qual.operator == "=", f"Unexpected operator for _operations: {qual.operator} Expected '='" - operations = json.loads(qual.value) - for op in operations: - if not isinstance(op, dict): - raise ValueError( - f"Invalid operation format: {op}. Expected a dictionary.") - if "type" not in op: - raise ValueError( - f"Operation must have 'type': {op}") - if op["type"] not in self._options.operation_types: - raise ValueError( - f"Invalid operation type: {op['type']}. 
Expected one of {self._options.operation_types}") - return operations - return None - - def _get_query(self, columns: Set[str], blobs: bool, as_format: Optional[str], operations: Optional[List[dict]], batch_size: int) -> List[dict]: + for col in columns: + col_type = self._columns[col].type + if not self._columns[col].listable: # special column; apply checks + clauses = [qual for qual in quals if qual.field_name == col] + + if len(clauses) > 1: + raise ValueError( + f"Multiple WHERE clauses for {col} are not allowed, got {len(clauses)} clauses.") + for clause in clauses: + if col_type == "boolean": + if clause.operator not in ["=", "<>", "IS", "IS NOT"]: + raise ValueError( + f"WHERE clauses for boolean column {col} can only use operators '=', '<>', 'IS', or 'IS NOT', got {clause.operator}") + else: + if clause.operator not in ["="]: + raise ValueError( + f"WHERE clauses for non-boolean column {col} can only use operators '=', got {clause.operator}") + + def _get_query(self, + quals: List[Qual], + columns: Set[str], + ) -> List[dict]: """ Construct the query to execute against ApertureDB. This is used to build the query based on the columns and options. """ - query = [{ - self._options.command: { - **self._options.extra, - **({"results": {"list": list(columns)}} if columns else {}), - "batch": { - "batch_id": 0, - "batch_size": batch_size - }, - **({"blobs": True} if blobs else {}), - **({"as_format": as_format} if as_format else {}), - **({"operations": operations} if operations else {}), - } - }] + listable_columns = { + col for col in columns if self._columns[col].listable} + + modifying_columns = { + col for col in columns if self._columns[col].modify_command_body is not None} + + command_body = {} + + if listable_columns: + command_body["results"] = {"list": list(listable_columns)} + + # Apply table modification, e.g. with_class, set + if self._options.modify_command_body: + self._options.modify_command_body(command_body=command_body) + + # Apply column modifications + for qual in quals: + if qual.field_name in modifying_columns: + + n_clauses = len( + [qual2 for qual2 in quals if qual2.field_name == qual.field_name]) + if n_clauses > 1: + raise ValueError( + f"Multiple WHERE clauses for {qual.field_name} are not allowed, got {n_clauses} clauses.") + + value = self._convert_qual_value(qual) + column_options = self._columns[qual.field_name] + column_options.modify_command_body( + command_body=command_body, value=value) + + # Check whether anyone has added the blobs parameter + blobs = command_body.get("blobs", False) + batch_size = BATCH_SIZE_WITH_BLOBS if blobs else BATCH_SIZE + command_body["batch"] = { + "batch_id": 0, + "batch_size": batch_size + } + + query = [{self._options.command: command_body}] + logger.debug( + f"Constructed query: {compact_pretty_json(query, indent=2)}") + return query - def _get_next_query(self, query: List[dict], response: List[dict]) -> Optional[List[dict]]: + def _convert_qual_value(self, qual: Qual) -> Any: """ - Get the next query to execute based on the response from the previous query. - This is used to handle batching. + Convert qual value into an internal value, depending on type and operator. 
""" - if not response or len(response) != 1: - logger.warning( - f"No results found for query: {query} -> {response}") - return None + col_type = self._columns[qual.field_name].type + if col_type == "boolean": + value = qual.value == 't' + if qual.operator in ["<>", "IS NOT"]: + value = not value + elif col_type == "json": + try: + value = json.loads(qual.value) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid JSON in {qual.field_name} clause: {e}") + else: + # TODO: Might be more conversions needed here + value = qual.value - if "batch" not in response[0][self._options.command]: - # Some commands (like FindConnection) don't handle batching, so we assume all results are returned at once. - logger.info( - f"Single batch found for query: {query} -> {response[:10]}") - return None + return value - batch_id = response[0][self._options.command]["batch"]["batch_id"] - total_elements = response[0][self._options.command]["batch"]["total_elements"] - end = response[0][self._options.command]["batch"]["end"] + def _get_query_blobs(self, quals: List[Qual], columns: Set[str]) -> List[bytes]: + query_blob_quals = [qual for qual in quals + if self._columns[qual.field_name].query_blobs is not None] - if end >= total_elements: # No more batches to process - return None + if query_blob_quals: + if len(query_blob_quals) > 1: + raise ValueError( + f"Multiple query blobs requested: {query_blob_quals}. Only one query blob is allowed per query.") - next_query = query.copy() - next_query[0][self._options.command]["batch"]["batch_id"] += 1 - return next_query + qual = query_blob_quals[0] + value = self._convert_qual_value(qual) + return self._columns[qual.field_name].query_blobs(value=value) - def _get_query_results(self, query: List[dict]) -> Generator[Tuple[dict, Optional[bytes]], None, List[dict]]: + return [] + + def _get_query_results(self, + query: List[dict], + query_blobs: List[bytes], + ) -> Generator[Tuple[dict, Optional[bytes]], None, List[dict]]: logger.debug(f"Executing query: {query}") start_time = datetime.now() - _, results, blobs = get_pool().execute_query(query) + _, results, response_blobs = get_pool().execute_query(query, query_blobs) elapsed_time = datetime.now() - start_time logger.info( - f"Query executed in {elapsed_time.total_seconds()} seconds. Results: {results}, Blobs: {len(blobs) if blobs else 0}") + f"Query executed in {elapsed_time.total_seconds()} seconds. Results: {results}, Blobs: {len(response_blobs) if response_blobs else 0}") if not results or len(results) != 1: logger.warning( @@ -189,111 +289,141 @@ def _get_query_results(self, query: List[dict]) -> Generator[Tuple[dict, Optiona result_objects = results[0].get( self._options.command, {}).get(self._options.result_field, []) - if not blobs: + if not response_blobs: for row in result_objects: yield row, None else: - for row, blob in zip_longest(result_objects, blobs): + for row, blob in zip_longest(result_objects, response_blobs): yield row or {}, blob - return results + return results, elapsed_time - def execute(self, quals, columns): - """ Execute the FDW query with the given quals and columns. + def _post_process_row(self, + quals: List[Qual], columns: Set[str], + row: dict, blob: Optional[bytes]) -> dict: + """ + Apply post-processing steps to the row before yielding it. - Args: - quals (list): List of conditions to filter the results. - Note that filtering is optional because PostgreSQL will also filter the results. - columns (set): List of columns to return in the results. 
+ This includes converting types from AQL to SQL, + adding non-list columns, and any column post-processing. """ + row = row.copy() # Avoid modifying the original row - logger.info( - f"Executing FDW {self._options.type}/{self._options.class_} with quals: {quals} and columns: {columns}") + for qual in quals: + if self._columns[qual.field_name].post_process_results is not None: + value = self._convert_qual_value(qual) + self._columns[qual.field_name].post_process_results( + row=row, value=value, blob=blob) - blobs = self._options.blob_column is not None and self._options.blob_column in columns - list_columns = {col for col in columns if self._columns[col].listable} + # Normalize the row to ensure it has the correct types for PostgreSQL + row = self._normalize_row(columns, row) - batch_size = BATCH_SIZE_WITH_BLOBS if blobs else BATCH_SIZE - as_format = self._get_as_format(quals) - operations = self._get_operations(quals) + row = self._add_non_list_columns(quals, columns, row) - query = self._get_query( - columns=list_columns, - blobs=blobs, - as_format=as_format, - operations=operations, - batch_size=batch_size) + return row - n_results = 0 - while query: - gen = self._get_query_results(query) - try: - while True: - row, blob = next(gen) - - # Add blob to the row if it exists - if blobs: - row[self._options.blob_column] = blob - - # Add special columns if they exist so that Postgres doesn't filter out rows - if as_format: - row["_as_format"] = as_format - if operations: - row["_operations"] = operations - - result = self._normalize_row(columns, row) - - logger.debug( - f"Yielding row: {json.dumps({k: v[:10] if isinstance(v, str) else len(v) if isinstance(v, (bytes, list)) else v for k, v in row.items()}, indent=2)}" - ) - n_results += 1 - if n_results % 1000 == 0: - logger.info( - f"Yielded {n_results} results so far for FDW {self._options.type}/{self._options.class_}") - yield result - except StopIteration as e: - response = e.value # return value from _get_query_results - - query = self._get_next_query(query, response) + def _normalize_row(self, columns, row: dict) -> dict: + """ + Normalize a row to ensure it has the correct types for PostgreSQL. + This is used to convert ApertureDB types to PostgreSQL types. + """ + result = {} + for col in columns: + if col not in row: + continue + value = self._convert_adb_value(row[col], col) + result[col] = value + return result - logger.info( - f"Executed FDW {self._options.type}/{self._options.class_} with {n_results} results") + def _convert_adb_value(self, value, col: str) -> Any: + """ + Convert an ApertureDB value to a PostgreSQL-compatible value. + """ + col_type = self._columns[col].type + if col_type == "datetime": + value = value["_date"] if value else None + elif col_type == "json": + value = json.dumps(value) + elif col_type == "blob": + value = value + # elif col_type == "boolean": + # value = 't' if value else 'f' + else: + value = value + return value - @classmethod - def import_schema(cls, schema, srv_options, options, restriction_type, restricts): + def _add_non_list_columns(self, quals: List[Qual], columns: Set[str], row: dict) -> dict: """ - Import the schema from ApertureDB and return a list of TableDefinitions. - This method is called when the foreign data wrapper is created. - The result of this is to create the foreign tables in PostgreSQL. + Add non-list columns to the row based on the quals. + This is used to ensure that all requested columns are present in the result. 
- Note that we cannot add comments, foreign keys, or other constraints here. + This is necessary because PostgreSQL will not return rows that don't meet the quals, and it doesn't know that we're + using special columns to do magic. - This method is called once per schema. + So we just copy the qual constraints into the row. """ - try: - # Put these here for better error handling - from .system import system_schema - from .entity import entity_schema - from .connection import connection_schema - from .descriptor import descriptor_schema + logger.debug( + f"Adding non-list columns to row: {row}, quals: {quals}, columns: {columns}") + row = row.copy() # Avoid modifying the original row + for qual in quals: + if not self._columns[qual.field_name].listable: + assert qual.field_name not in row, f"Column {qual.field_name} should not be in the row. It is a non-list column." + logger.debug( + f"Adding non-list column {qual.field_name} with value {qual.value} to row" + ) + # This double-conversion is necessary because of negative qual operators like IS NOT and <>. + value = self._convert_qual_value(qual) + value = self._convert_adb_value(value, qual.field_name) + row[qual.field_name] = value + + return row - logger.info(f"Importing schema {schema} with options: {options}") - if schema == "system": - return system_schema() - elif schema == "entity": - return entity_schema() - elif schema == "connection": - return connection_schema() - elif schema == "descriptor": - return descriptor_schema() - else: - raise ValueError(f"Unknown schema: {schema}") - except: - logger.exception( - f"Error importing schema {schema}: {sys.exc_info()[1]}") - flush_logs() - raise - logger.info(f"Schema {schema} imported successfully") + def _get_next_query(self, query: List[dict], response: List[dict]) -> Optional[List[dict]]: + """ + Get the next query to execute based on the response from the previous query. + This is used to handle batching. + """ + if not response or len(response) != 1: + logger.warning( + f"No results found for query: {query} -> {response}") + return None + + if "batch" not in response[0][self._options.command]: + # Some commands (like FindConnection) don't handle batching, so we assume all results are returned at once. + logger.info( + f"Single batch found for query: {query} -> {response[:10]}") + return None + + batch_id = response[0][self._options.command]["batch"]["batch_id"] + total_elements = response[0][self._options.command]["batch"]["total_elements"] + end = response[0][self._options.command]["batch"]["end"] + + if end >= total_elements: # No more batches to process + return None + + next_query = query.copy() + next_query[0][self._options.command]["batch"]["batch_id"] += 1 + return next_query + + def explain(self, quals: List[Qual], columns: Set[str], sortkeys=None, verbose=False) -> Iterable[str]: + """ + Generate an EXPLAIN statement for the FDW query. + This is used to provide information about how the query will be executed. + """ + logger.info( + f"Explaining FDW {self._options.table_name} with quals: {quals} and columns: {columns}") + self._check_quals(quals, columns) + + result = [f"FDW: {self._options.table_name}"] + query = self._get_query(quals, columns) + result.append(f"AQL: {compact_pretty_json(query, indent=2)}") + query_blobs = self._get_query_blobs(quals, columns) + # This part isn't verbose, but can be much slower + if query_blobs and verbose: + result.append( + f"Query Blob: {len(query_blobs[0])} bytes - {query_blobs[0][:10]}... 
(truncated)") + + return result -print("FDW class defined successfully") +logger.info("FDW class defined successfully") diff --git a/apps/sql-server/fdw/fdw/column.py b/apps/sql-server/fdw/fdw/column.py new file mode 100644 index 00000000..5f4180bd --- /dev/null +++ b/apps/sql-server/fdw/fdw/column.py @@ -0,0 +1,183 @@ +from .common import Curry, TYPE_MAP +import logging +from pydantic import BaseModel +from typing import List, Dict, Any, Optional, Literal +from multicorn import ColumnDefinition +import json + +logger = logging.getLogger(__name__) + + +class ColumnOptions(BaseModel): + """ + Options passed to the foreign table columns from `import_schema`. + """ + count: Optional[int] = None # number of matched objects for this column + indexed: bool = False # whether the column is indexed + # AQL type of the column: "string", "number", "boolean", "json", "blob" + type: Optional[Literal["string", + "number", "boolean", "json", "blob", "datetime"]] = None + listable: bool = True # whether the column can be passed to results/list + unique: bool = False # whether the column is unique, used for _uniqueid + + # These three hooks are used to provide special handling for certain columns. + # All of them are invoked when the column is used in a qual with an equality operator. + # All are passed the qual value as `value` and may take other keyword arguments. + # Always use the `Curry` class to pass these functions as options. + # This ensures they are serialized correctly and can be executed later. + + # modify_command_body: also passed `command_body` and expected to modify in place + modify_command_body: Optional[Curry] = None + + # query_blobs: returns a list of bytes + query_blobs: Optional[Curry] = None + + # post_process_results: also passed `row` (before type normalization) and expected to modify in place + post_process_results: Optional[Curry] = None + + def model_post_init(self, context: Any): + """ + Validate the options after model initialization. + """ + if self.listable and not self.type: + raise ValueError("listable columns must have a type defined") + + # Check that hooks have valid function signatures + if self.modify_command_body: + self.modify_command_body.validate_signature( + {"value", "command_body"}) + if self.query_blobs: + self.query_blobs.validate_signature({"value"}) + if self.post_process_results: + self.post_process_results.validate_signature( + {"value", "row", "blob"}) + + @classmethod + def from_string(cls, options_str: Dict[str, str]) -> "ColumnOptions": + """ + Create a ColumnOptions instance from a string dictionary. + This is used to decode options from the foreign table column definition. + """ + options = json.loads(options_str["column_options"]) + return cls(**options) + + def to_string(self) -> Dict[str, str]: + """ + Convert ColumnOptions to a string dictionary. + This is used to encode options for the foreign table column definition. + """ + return {"column_options": json.dumps(self.dict(), default=str)} + + # Reject any extra fields that are not defined in the model. + model_config = { + "extra": "forbid" + } + + +# Some utility functions for Curry hooks + + +def passthrough(name: str, + value: Any, command_body: Dict[str, Any]) -> None: + """ + A ColumnOptions modify_command_body hook. + + Adds the value to the command body under the given name. + """ + command_body[name] = value + + +def add_blob(column: str, + value: Any, row: dict, blob: Optional[bytes]) -> None: + """ + A ColumnOptions post_process_results hook.
+ + Adds the value to the row under the given column name, + if the value is true. + """ + assert isinstance(value, bool), \ + f"Expected value to be a boolean, got {type(value)}" + if value: + row[column] = blob + +# Some utility functions for creating column definitions + + +def blob_columns(column: str) -> List[ColumnDefinition]: + """ + Constructs column definitions for object types that can contain blobs. + + This approach allows us to control whether the blobs are returned + based on the query, while still providing a column for the blob data. + In particular, the boolean column avoids the awkwardness of including the blob data in "SELECT *". + + Args: + column (str): The name of the column to use for blobs, e.g. _blob + + Returns: + column_definitions: A list of two column definitions: + - A boolean column `_blobs` indicating if the query should return blobs + - A blob column for the actual blob data + """ + return [ + ColumnDefinition( + column_name="_blobs", + type_name="boolean", + options=ColumnOptions( + type="boolean", + listable=False, + modify_command_body=Curry(passthrough, "blobs"), + post_process_results=Curry(add_blob, column=column) + ).to_string()), + ColumnDefinition( + column_name=column, + type_name="bytea", + options=ColumnOptions( + type="blob", + listable=False, + ).to_string()) + ] + + +def property_columns(data: dict) -> List[ColumnDefinition]: + """ + Create a list of ColumnDefinitions for the given properties. + This is used to create the foreign table in PostgreSQL. + """ + columns = [] + if "properties" in data and data["properties"] is not None: + assert isinstance(data["properties"], dict), \ + f"Expected properties to be a dict, got {type(data['properties'])}" + for prop, prop_data in data["properties"].items(): + try: + count, indexed, type_ = prop_data + columns.append(ColumnDefinition( + column_name=prop, + type_name=TYPE_MAP[type_.lower()], + options=ColumnOptions( + count=count, + indexed=indexed, + type=type_.lower(), + ).to_string()) + ) + except Exception as e: + logger.exception( + f"Error processing property '{prop}': {e}") + raise + + columns.append(uniqueid_column(data.get("matched", 0))) + + return columns + + +def uniqueid_column(count: int = 0) -> ColumnDefinition: + """ Create a ColumnDefinition for the _uniqueid column.
""" + return ColumnDefinition( + column_name="_uniqueid", + type_name="text", + options=ColumnOptions( + count=count, + indexed=True, + unique=True, + type="string" + ).to_string()) diff --git a/apps/sql-server/fdw/fdw/common.py b/apps/sql-server/fdw/fdw/common.py index ee4fc1dd..76ea4113 100644 --- a/apps/sql-server/fdw/fdw/common.py +++ b/apps/sql-server/fdw/fdw/common.py @@ -1,16 +1,38 @@ -import json -import os -import sys -import logging -from dotenv import load_dotenv -from multicorn import ColumnDefinition -from typing import List, Optional, Dict +from typing import Callable, Any +from pydoc import locate +from pydantic import BaseModel, GetCoreSchemaHandler, GetJsonSchemaHandler, TypeAdapter from collections import defaultdict -from pydantic import BaseModel +from typing import List, Optional, Dict, Callable, Any +from multicorn import ColumnDefinition +from dotenv import load_dotenv +import logging +import sys +import os +import json +from contextlib import contextmanager +from pydantic_core import core_schema +import inspect logger = logging.getLogger(__name__) +@contextmanager +def import_path(path): + original = list(sys.path) + sys.path.insert(0, path) + try: + yield + finally: + sys.path = original + + +@contextmanager +def import_from_app(): + """We don't control the import path, so we need to ensure the app directory is in the path.""" + with import_path('/app'): + yield + + def load_aperturedb_env(path="/app/aperturedb.env"): """Load environment variables from a file. This is used because FDW is executed in a "secure" environment where @@ -21,8 +43,8 @@ def load_aperturedb_env(path="/app/aperturedb.env"): load_dotenv(dotenv_path=path, override=True) -_POOL = None # Global connection pool -_SCHEMA = None # Global schema variable +_POOL = None # Global connection pool; see get_pool() +_SCHEMA = None # Global schema variable; see get_schema() def get_log_level() -> int: @@ -35,10 +57,10 @@ def get_log_level() -> int: def get_pool() -> "ConnectionPool": """Get the global connection pool. Lazy initialization.""" load_aperturedb_env() - sys.path.append('/app') - from connection_pool import ConnectionPool global _POOL if _POOL is None: + with import_from_app(): + from connection_pool import ConnectionPool _POOL = ConnectionPool() logger.info("Connection pool initialized") return _POOL @@ -65,93 +87,135 @@ def get_schema() -> Dict: } -def property_columns(data: dict) -> List[ColumnDefinition]: - """ - Create a list of ColumnDefinitions for the given properties. - This is used to create the foreign table in PostgreSQL. - """ - columns = [] - if "properties" in data and data["properties"] is not None: - assert isinstance(data["properties"], dict), \ - f"Expected properties to be a dict, got {type(data['properties'])}" - for prop, prop_data in data["properties"].items(): - try: - count, indexed, type_ = prop_data - columns.append(ColumnDefinition( - column_name=prop, - type_name=TYPE_MAP[type_.lower()], - options=ColumnOptions(count=count, indexed=indexed, type=type_.lower()).to_string())) - except Exception as e: - logger.exception( - f"Error processing property '{prop}': {e}") - raise - - # Add the _uniqueid column - columns.append(ColumnDefinition( - column_name="_uniqueid", type_name="text", options=ColumnOptions(count=data.get("matched", 0), indexed=True, unique=True, type="string").to_string())) - - return columns - - -class TableOptions(BaseModel): +logger = logging.getLogger(__name__) + + +class Curry: """ - Options passed to the foreign table from `import_schema`. 
+ This class is used to wrap functions so that they can be serialized and deserialized as text. + The function is stored as a reference to its module and qualified name, + which basically means that it has to be a named top-level function in a module. + Positional and keyword arguments can be supplied to give the effect of a function closure. + These must be JSON-serializable. + + The reason for all this is that we want to be able to store functions in TableOptions and ColumnOptions, + in order to specialize the behaviour of the `execute` method, + but those datastructures are serialized to JSON in `import_schema`, + because Postgres stores them in the database as text fields, + and then later passes them to `FDW.__init__`, which deserializes them back into Python objects. """ - class_: Optional[str] = None # class name of the entity or connection as reported by GetSchema - type: str = "entity" # object type, e.g. "entity", "connection", "descriptor" - count: int = 0 # number of matched objects - # command to execute, e.g. "FindEntity", "FindConnection", etc. - command: str = "FindEntity" - # field to look for in the response, e.g. "entities", "connections" - result_field: str = "entities" - extra: dict = {} # additional options for the command - blob_column: Optional[str] = None # column containing blob data - # properties of the descriptor set, if applicable - descriptor_set_properties: Optional[dict] = None - # operation types for the descriptor set, if applicable - operation_types: Optional[List[str]] = None + + def __init__(self, func: Callable, *args, **kwargs): + self.func = func + self.args = args + self.kwargs = kwargs + + if not hasattr(func, "__module__") or not hasattr(func, "__qualname__"): + raise TypeError( + f"Expected a function or method with __module__ and __qualname__, " + f"got object of type {type(func)}: {repr(func)}" + ) + + def validate_signature(self, required_keywords: set): + sig = inspect.signature(self.func) + param_names = set(sig.parameters.keys()) + accepts_kwargs = any( + p.kind == inspect.Parameter.VAR_KEYWORD + for p in sig.parameters.values() + ) + + if not accepts_kwargs and not required_keywords <= param_names: + missing = required_keywords - param_names + raise TypeError( + f"Function {self.func} missing required keywords: {missing}") + + overlap = required_keywords & self.kwargs.keys() + if overlap: + raise TypeError( + f"Curry for {self.func} should not override required args: {overlap}") + + def __call__(self, **kwargs): + return self.func(*self.args, **self.kwargs, **kwargs) + + def to_json(self): + return { + "__curry__": True, + "module": self.func.__module__, + "qualname": self.func.__qualname__, + "args": self.args, + "kwargs": self.kwargs, + } @classmethod - def from_string(cls, options_str: Dict[str, str]) -> "TableOptions": - """ - Create a TableOptions instance from a string dictionary. - This is used to decode options from the foreign table definition. - Postgres restricts options to be a string-valued dictionary. - """ - options = json.loads(options_str["table_options"]) - return cls(**options) - - def to_string(self) -> Dict[str, str]: - """ - Convert TableOptions to a string dictionary. - This is used to encode options for the foreign table definition. - """ - return {"table_options": json.dumps(self.dict(), default=str)} - - -class ColumnOptions(BaseModel): - """ - Options passed to the foreign table columns from `import_schema`. 
- """ - count: Optional[int] = None # number of matched objects for this column - indexed: bool = False # whether the column is indexed - type: str # AQL type of the column: "string", "number", "boolean", "json", "blob" - # whether the column has special meaning (e.g. _blob, _image) - listable: bool = True # whether the column can be passed to results/list - unique: bool = False # whether the column is unique, used for _uniqueid + def from_json(cls, data: dict): + if not data.get("__curry__"): + raise ValueError("Invalid Curry JSON data") + module = data["module"] + qualname = data["qualname"] + args = data.get("args", []) + kwargs = data.get("kwargs", {}) + func = locate(f"{module}.{qualname}") + if func is None: + raise ValueError(f"Could not locate function {module}.{qualname}") + return cls(func, *args, **kwargs) + + @classmethod + def _validate(cls, value: Any) -> "Curry": + logger.debug(f"Validating Curry: {value}") + if isinstance(value, cls): + return value + elif isinstance(value, dict): + return cls.from_json(value) + raise TypeError(f"Cannot convert {value!r} to Curry") + + @classmethod + def _serialize(cls, value: "Curry") -> dict: + return value.to_json() @classmethod - def from_string(cls, options_str: Dict[str, str]) -> "ColumnOptions": - """ - Create a ColumnOptions instance from a string dictionary. - This is used to decode options from the foreign table column definition. - """ - options = json.loads(options_str["column_options"]) - return cls(**options) - - def to_string(self) -> Dict[str, str]: - """ - Convert ColumnOptions to a string dictionary. - This is used to encode options for the foreign table column definition. - """ - return {"column_options": json.dumps(self.dict(), default=str)} + def __get_pydantic_core_schema__(cls, source_type, handler): + return core_schema.no_info_after_validator_function( + cls._validate, + core_schema.any_schema(), + serialization=core_schema.plain_serializer_function_ser_schema( + cls._serialize, when_used="always" + ) + ) + + def __repr__(self): + return f"Curry({self.func.__module__}.{self.func.__qualname__}, args={self.args}, kwargs={self.kwargs})" + + +def compact_pretty_json(data: Any, line_length=78, level=0, indent=2) -> str: + """ + Compact yet pretty JSON representation of the data. + If it fits in one line (accounting for indentation), it stays one line. + Otherwise, falls back to a multi-line indented format. + """ + prefix = " " * (level * indent) + one_line = json.dumps(data, ensure_ascii=False) + + if len(prefix) + len(one_line) <= line_length: + return prefix + one_line + + if isinstance(data, (list, tuple)): + lines = [prefix + "["] + for item in data: + lines.append(compact_pretty_json( + item, line_length, level + 1, indent)) + lines.append(prefix + "]") + return "\n".join(lines) + + elif isinstance(data, dict): + lines = [prefix + "{"] + for i, (key, value) in enumerate(data.items()): + key_str = json.dumps(key, ensure_ascii=False) + ": " + val_str = compact_pretty_json( + value, line_length, level + 1, indent) + lines.append(" " * ((level + 1) * indent) + + key_str + val_str.lstrip()) + lines.append(prefix + "}") + return "\n".join(lines) + + else: + return prefix + json.dumps(data, ensure_ascii=False) diff --git a/apps/sql-server/fdw/fdw/connection.py b/apps/sql-server/fdw/fdw/connection.py index 8be34428..da5ae7e1 100644 --- a/apps/sql-server/fdw/fdw/connection.py +++ b/apps/sql-server/fdw/fdw/connection.py @@ -1,6 +1,16 @@ +# This module populates the connection schema for ApertureDB. 
+# This schema contains a table for each connection class. +# Hence every AQL query includes `with_class`. +# In addition to the usual _uniqueid, tables have columns _src and _dst. +# +# SELECT _uniqueid, _src, _dst +# FROM "WorkflowCreated"; + import logging from typing import List -from .common import property_columns, get_schema, TableOptions, ColumnOptions +from .common import get_schema, Curry +from .column import property_columns, ColumnOptions +from .table import TableOptions, literal from multicorn import TableDefinition, ColumnDefinition logger = logging.getLogger(__name__) @@ -32,11 +42,11 @@ def connection_table(connection: str, data: dict) -> TableDefinition: table_name = connection options = TableOptions( - class_=connection, - type="connection", + table_name=f'connection."{table_name}"', count=data.get("matched", 0), command="FindConnection", result_field="connections", + modify_command_body=Curry(literal, {"with_class": connection}), ) columns = [] @@ -46,9 +56,21 @@ def connection_table(connection: str, data: dict) -> TableDefinition: # Add the _src, and _dst columns columns.append(ColumnDefinition( - column_name="_src", type_name="text", options=ColumnOptions(class_=data["src"], count=data.get("matched", 0), indexed=True, type="string").to_string())) + column_name="_src", + type_name="text", + options=ColumnOptions( + count=data.get("matched", 0), + indexed=True, + type="string", + ).to_string())) columns.append(ColumnDefinition( - column_name="_dst", type_name="text", options=ColumnOptions(class_=data["dst"], count=data.get("matched", 0), indexed=True, type="string").to_string())) + column_name="_dst", + type_name="text", + options=ColumnOptions( + count=data.get("matched", 0), + indexed=True, + type="string", + ).to_string())) except Exception as e: logger.exception( f"Error processing properties for connection {connection}: {e}") @@ -60,5 +82,4 @@ def connection_table(connection: str, data: dict) -> TableDefinition: return TableDefinition( table_name=table_name, columns=columns, - options=options.to_string() - ) + options=options.to_string()) diff --git a/apps/sql-server/fdw/fdw/descriptor.py b/apps/sql-server/fdw/fdw/descriptor.py index 52a4561a..b002ce12 100644 --- a/apps/sql-server/fdw/fdw/descriptor.py +++ b/apps/sql-server/fdw/fdw/descriptor.py @@ -1,26 +1,27 @@ +# This module populates the descriptor schema for ApertureDB. +# This schema contains a table for each descriptor set, +# and hence every AQL query includes `set`. +# These tables support find-similar queries. +# +# SELECT * FROM descriptor."crawl-to-rag" +# WHERE _find_similar = FIND_SIMILAR(text := 'find entity', k := 10) +# AND _blobs + from multicorn import TableDefinition, ColumnDefinition from typing import List -from .common import property_columns, get_schema, get_pool, TableOptions +from .common import get_schema, get_pool, import_from_app, Curry +from .column import property_columns, ColumnOptions, blob_columns +from .table import TableOptions, literal import logging +import numpy as np +import json +from datetime import datetime -logger = logging.getLogger(__name__) +with import_from_app(): + from embeddings import Embedder -def get_descriptor_sets() -> dict: - """ - Get the descriptor sets from the environment variable. - This is used to create the foreign tables for descriptor sets. 
- """ - query = [{"FindDescriptorSet": {"results": {"all_properties": True}, - "counts": True, "engines": True, "dimensions": True, "metrics": True}}] - _, response, _ = get_pool().execute_query(query) - if "entities" not in response[0]["FindDescriptorSet"]: - return {} - - results = { - e['_name']: e for e in response[0]["FindDescriptorSet"]["entities"] - } - return results +logger = logging.getLogger(__name__) def descriptor_schema() -> List[TableDefinition]: @@ -37,19 +38,52 @@ def descriptor_schema() -> List[TableDefinition]: for name, properties in descriptor_sets.items(): table_name = name + # We switch the "find similar" feature on if the descriptor set + # has properties that allow us to find the correct embedding model. + # Notionally we could allow direct vector queries regardless, but + # this is a good heuristic to avoid unnecessary complexity. + find_similar = Embedder.check_properties(properties) + options = TableOptions( - class_="_Descriptor", - type="entity", + table_name=f'descriptor."{table_name}"', count=properties["_count"], command="FindDescriptor", result_field="entities", - extra={"set": name}, - descriptor_set_properties=properties, + modify_command_body=Curry( + literal, {"set": name, "distances": True}), ) - # TODO: We're giving all tables the same columns, which is not ideal. - columns = property_columns(get_schema().get( - "entities", {}).get("classes", {}).get("_Descriptor", {})) + columns = property_columns_for_descriptors_in_set(name) + + if find_similar: + columns.append(ColumnDefinition( + column_name="_find_similar", + type_name="JSONB", + options=ColumnOptions( + listable=False, + modify_command_body=Curry( + find_similar_modify_command_body), + query_blobs=Curry(find_similar_query_blobs, + properties=properties, + descriptor_set=name), + ).to_string() + )) + + # This special column has a parameter, but is also listable, + # but does not appear in the schema. + columns.append(ColumnDefinition( + column_name="_distance", + type_name="double precision", + options=ColumnOptions( + listable=True, + type="number", + ).to_string() + )) + + # Special field _label has a parameter, but is also listable, + # and does appear in the schema, so we skip it here. + + columns.extend(blob_columns("_vector")) logger.debug( f"Creating table {table_name} with options {options.to_string()} and columns {columns}") @@ -60,4 +94,152 @@ def descriptor_schema() -> List[TableDefinition]: options=options.to_string() ) results.append(table) + return results + + +def get_descriptor_sets() -> dict: + """ + Get the descriptor sets from the environment variable. + This is used to create the foreign tables for descriptor sets. + """ + query = [{"FindDescriptorSet": {"results": {"all_properties": True}, + "counts": True, "engines": True, "dimensions": True, "metrics": True}}] + _, response, _ = get_pool().execute_query(query) + if "entities" not in response[0]["FindDescriptorSet"]: + return {} + + results = { + e['_name']: e for e in response[0]["FindDescriptorSet"]["entities"] + } + return results + + +def property_columns_for_descriptors_in_set(name: str) -> List[ColumnDefinition]: + """ + Get the property columns for a specific descriptor set. 
+ """ + query = [{ + "FindDescriptor": { + "set": name, + "_ref": 1 + } + }, { + "GetSchema": { + "ref": 1 + } + }] + + _, response, _ = get_pool().execute_query(query) + + properties = response[1]["GetSchema"].get( + "entities", {}).get("classes", {}).get("_Descriptor", {}) + + columns = property_columns(properties) + + return columns + + +def find_similar_modify_command_body( + value: str, command_body: dict) -> None: + """ + Modify the command body for a find similar query. + + Args: + value: JSON string generated from the FIND_SIMILAR SQL function + command_body: The command body to modify + + Side Effects: + Modifies the command body in place to include the find similar parameters. + """ + try: + find_similar = json.loads(value) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid JSON for _find_similar: {value}") from e + + logger.debug(f"find_similar: {find_similar}") + + if not isinstance(find_similar, dict): + raise ValueError( + f"Invalid find_similar format: {find_similar}. Expected an object.") + + include_list = {"k_neighbors", "knn_first"} + extra = {k: v for k, v in find_similar.items( + ) if k in include_list and v is not None} + command_body.update(extra) + + +def find_similar_query_blobs( + properties: dict, descriptor_set: str, + value: str) -> List[bytes]: + """ + Generates vector data for find similar operations. + + Args: + properties: The properties of the descriptor set. + descriptor_set: The name of the descriptor set. + value: JSON string generated from the FIND_SIMILAR SQL function + + Returns: + blobs: list of length one containing the vector data as bytes + """ + try: + find_similar = json.loads(value) + except json.JSONDecodeError as e: + raise ValueError( + f"Invalid JSON for _find_similar: {value}") from e + + logger.debug(f"find_similar: {find_similar}") + + if not isinstance(find_similar, dict): + raise ValueError( + f"Invalid find_similar format: {find_similar}. Expected an object.") + + if "vector" in find_similar and find_similar["vector"] is not None: + raw_vector = find_similar["vector"] + dimensions = properties["_dimensions"] + # Validate that raw_vector is a list or tuple of the correct length and numeric type + if not isinstance(raw_vector, (list, tuple)): + raise ValueError( + f"Invalid vector type: {type(raw_vector)}. Expected list or tuple.") + if len(raw_vector) != dimensions: + raise ValueError( + f"Invalid vector length: {len(raw_vector)}. Expected {dimensions}.") + if not all(isinstance(x, (int, float)) for x in raw_vector): + raise ValueError( + f"Invalid vector contents: {raw_vector}. All elements must be numeric (int or float).") + vector = np.array(raw_vector, dtype=np.float32) + else: + # This takes ~7s the first time, but ~1s on subsequent calls because of a file cache. + # Could consider caching the embedder and maybe even doing cache warmup, + # but the peculiar invocation environment of Python within PostgreSQL + # makes this tricky, because we can't consistently persist state across calls. 
+ start_time = datetime.now() + embedder = Embedder.from_properties( + properties=properties, + descriptor_set=descriptor_set, + ) + elapsed_time = datetime.now() - start_time + logger.debug( + f"Creating embedder took {elapsed_time.total_seconds()} seconds for descriptor set {descriptor_set}") + + start_time = datetime.now() + if "text" in find_similar and find_similar["text"] is not None: + text = find_similar["text"] + vector = embedder.embed_text(text) + elif "image" in find_similar and find_similar["image"] is not None: + image = find_similar["image"] + vector = embedder.embed_image(image) + else: + raise ValueError( + "find_similar must have one of 'text', 'image', or 'vector' to embed.") + elapsed_time = datetime.now() - start_time + logger.debug( + f"Embedding took {elapsed_time.total_seconds()} seconds for descriptor set {descriptor_set}") + + # Use .data.tobytes() to avoid unnecessary copying if the array is contiguous + if not vector.flags['C_CONTIGUOUS']: + vector = np.ascontiguousarray(vector, dtype=np.float32) + blob = vector.data.tobytes() + return [blob] # Return as a list of one blob diff --git a/apps/sql-server/fdw/fdw/entity.py b/apps/sql-server/fdw/fdw/entity.py index cb850424..2790c6ce 100644 --- a/apps/sql-server/fdw/fdw/entity.py +++ b/apps/sql-server/fdw/fdw/entity.py @@ -1,11 +1,36 @@ +# This module populates the entity schema for ApertureDB. +# This schema contains a table for each entity class. +# Hence every AQL query includes `with_class`. +# +# SELECT * FROM "CrawlDocument" LIMIT 10; + from multicorn import TableDefinition -from .common import property_columns, get_schema, TableOptions +from .common import get_schema, Curry +from .column import property_columns +from .table import TableOptions, literal from typing import List import logging logger = logging.getLogger(__name__) +def entity_schema() -> List[TableDefinition]: + """ + Return the entity schema for ApertureDB. + This is used to create the foreign tables for entity classes. + """ + logger.info("Creating entity schema") + results = [] + schema = get_schema() + if "entities" in schema and "classes" in schema["entities"]: + assert isinstance(schema["entities"]["classes"], dict), \ + f"Expected entities.classes to be a dict, got {type(schema['entities']['classes'])}" + for entity, data in schema["entities"]["classes"].items(): + if entity[0] != "_": + results.append(entity_table(entity, data)) + return results + + def entity_table(entity: str, data: dict) -> TableDefinition: """ Create a TableDefinition for an entity. @@ -18,12 +43,11 @@ def entity_table(entity: str, data: dict) -> TableDefinition: table_name = entity options = TableOptions( - class_=entity, - type="entity", - matched=data.get("matched", 0), - extra={"with_class": entity}, + table_name=f'entity."{table_name}"', + count=data.get("matched", 0), command="FindEntity", result_field="entities", + modify_command_body=Curry(literal, {"with_class": entity}), ) columns = [] @@ -43,20 +67,3 @@ def entity_table(entity: str, data: dict) -> TableDefinition: columns=columns, options=options.to_string() ) - - -def entity_schema() -> List[TableDefinition]: - """ - Return the entity schema for ApertureDB. - This is used to create the foreign tables for entity classes. 
- """ - logger.info("Creating entity schema") - results = [] - schema = get_schema() - if "entities" in schema and "classes" in schema["entities"]: - assert isinstance(schema["entities"]["classes"], dict), \ - f"Expected entities.classes to be a dict, got {type(schema['entities']['classes'])}" - for entity, data in schema["entities"]["classes"].items(): - if entity[0] != "_": - results.append(entity_table(entity, data)) - return results diff --git a/apps/sql-server/fdw/fdw/system.py b/apps/sql-server/fdw/fdw/system.py index 1069139d..ca857efb 100644 --- a/apps/sql-server/fdw/fdw/system.py +++ b/apps/sql-server/fdw/fdw/system.py @@ -1,5 +1,14 @@ -from typing import List, Literal -from .common import property_columns, get_schema, TYPE_MAP, TableOptions, ColumnOptions +# This module populates the system schema for ApertureDB. +# This schema contains tables for system classes such as _Blob, _Image, etc. +# Note that the tables for these are named without the leading underscore. +# It also includes special tables for Entity and Connection. +# +# SELECT * FROM "DescriptorSet" LIMIT 10; + +from typing import List, Literal, Set, Any, Dict +from .common import get_schema, TYPE_MAP, Curry +from .column import property_columns, ColumnOptions, blob_columns, uniqueid_column, passthrough +from .table import TableOptions, literal from multicorn import TableDefinition, ColumnDefinition import logging from collections import defaultdict @@ -33,6 +42,77 @@ def system_schema() -> List[TableDefinition]: return results +def operations_column(types: Set[str]) -> ColumnDefinition: + return ColumnDefinition( + column_name="_operations", + type_name="jsonb", + options=ColumnOptions( + type="json", + listable=False, + modify_command_body=Curry(operations_passthrough, types=types), + ).to_string() + ) + + +def operations_passthrough(types: Set[str], + value: Any, command_body: Dict[str, Any]) -> None: + """ + This is a ColumnOptions modify_command_body hook. + + Pass through the operations from the query to the command body. + Includes JSON conversion and some validation. + + Error messages are SQL-oriented. + """ + operations = value + if not isinstance(operations, list): + raise ValueError( + f"Operations must be an array, got {operations} {type(operations)}") + + for op in operations: + if not isinstance(op, dict): + raise ValueError( + f"Invalid operation format: {op}. Expected an object.") + if "type" not in op: + raise ValueError( + f"Operation must have 'type' field: {op}") + if op["type"] not in types: + raise ValueError( + f"Invalid operation type: {op['type']}. Expected one of {types}") + + command_body["operations"] = operations + + +def blob_extra_columns() -> List[ColumnDefinition]: + """ + Returns the columns for a _Blob entity. + """ + return blob_columns("_blob") + + +def image_extra_columns() -> List[ColumnDefinition]: + """ + Returns the columns for a _Image entity. + """ + return blob_columns("_image") + [ + ColumnDefinition( + column_name="_as_format", + type_name="text", + options=ColumnOptions( + listable=False, + modify_command_body=Curry(passthrough, "as_format"), + ).to_string() + ), + operations_column({"threshold", "resize", "crop", "rotate", "flip"}), + ] + + +OBJECT_COLUMNS_HANDLERS = { + "_Blob": blob_extra_columns, + "_Image": image_extra_columns, +} + + def system_table(entity: str, data: dict) -> TableDefinition: """ Create a TableDefinition for a system object. 
diff --git a/apps/sql-server/fdw/fdw/system.py b/apps/sql-server/fdw/fdw/system.py
index 1069139d..ca857efb 100644
--- a/apps/sql-server/fdw/fdw/system.py
+++ b/apps/sql-server/fdw/fdw/system.py
@@ -1,5 +1,14 @@
-from typing import List, Literal
-from .common import property_columns, get_schema, TYPE_MAP, TableOptions, ColumnOptions
+# This module populates the system schema for ApertureDB.
+# This schema contains tables for system classes such as _Blob, _Image, etc.
+# Note that the tables for these are named without the leading underscore.
+# It also includes special tables for Entity and Connection.
+#
+# SELECT * FROM "DescriptorSet" LIMIT 10;
+
+from typing import List, Literal, Set, Any, Dict
+from .common import get_schema, TYPE_MAP, Curry
+from .column import property_columns, ColumnOptions, blob_columns, uniqueid_column, passthrough
+from .table import TableOptions, literal
 from multicorn import TableDefinition, ColumnDefinition
 import logging
 from collections import defaultdict
@@ -33,6 +42,77 @@ def system_schema() -> List[TableDefinition]:
     return results
 
 
+def operations_column(types: Set[str]) -> ColumnDefinition:
+    return ColumnDefinition(
+        column_name="_operations",
+        type_name="jsonb",
+        options=ColumnOptions(
+            type="json",
+            listable=False,
+            modify_command_body=Curry(operations_passthrough, types=types),
+        ).to_string()
+    )
+
+
+def operations_passthrough(types: Set[str],
+                           value: Any, command_body: Dict[str, Any]) -> None:
+    """
+    This is a ColumnOptions modify_command_body hook.
+
+    Pass through the operations from the query to the command body.
+    Includes JSON conversion and some validation.
+
+    Error messages are SQL-oriented.
+    """
+    operations = value
+    if not isinstance(operations, list):
+        raise ValueError(
+            f"Operations must be an array, got {operations} {type(operations)}")
+
+    for op in operations:
+        if not isinstance(op, dict):
+            raise ValueError(
+                f"Invalid operation format: {op}. Expected an object.")
+        if "type" not in op:
+            raise ValueError(
+                f"Operation must have 'type' field: {op}")
+        if op["type"] not in types:
+            raise ValueError(
+                f"Invalid operation type: {op['type']}. Expected one of {types}")
+
+    command_body["operations"] = operations
+
+
+def blob_extra_columns() -> List[ColumnDefinition]:
+    """
+    Returns the columns for a _Blob entity.
+    """
+    return blob_columns("_blob")
+
+
+def image_extra_columns() -> List[ColumnDefinition]:
+    """
+    Returns the columns for an _Image entity.
+    """
+    return blob_columns("_image") + [
+        ColumnDefinition(
+            column_name="_as_format",
+            type_name="text",
+            options=ColumnOptions(
+                listable=False,
+                modify_command_body=Curry(passthrough, "as_format"),
+            ).to_string()
+        ),
+        operations_column({"threshold", "resize", "crop", "rotate", "flip"}),
+    ]
+
+
+OBJECT_COLUMNS_HANDLERS = {
+    "_Blob": blob_extra_columns,
+    "_Image": image_extra_columns,
+}
+
+
 def system_table(entity: str, data: dict) -> TableDefinition:
     """
     Create a TableDefinition for a system object.
@@ -50,42 +130,14 @@ def system_table(entity: str, data: dict) -> TableDefinition:
         table_name = entity[1:]
 
         options = TableOptions(
-            class_=entity,
-            type=entity,
-            matched=data.get("matched", 0),
+            table_name=f'system."{table_name}"',
+            count=data.get("matched", 0),
             command=f"Find{entity[1:]}",  # e.g. FindBlob, FindImage, etc.
             result_field="entities",
         )
 
-        # Blob-like entities get special columns specific to the type
-        if entity == "_Blob":
-            # _Blob gets _blob column
-            columns.append(ColumnDefinition(
-                column_name="_blob", type_name="bytea",
-                options=ColumnOptions(
-                    count=data.get("matched", 0), indexed=False, type="blob", listable=False).to_string())
-            )
-            options.blob_column = "_blob"
-        elif entity == "_Image":
-            # _Image gets _image, _as_format, _operations columns
-            columns.append(ColumnDefinition(
-                column_name="_image", type_name="bytea",
-                options=ColumnOptions(
-                    count=data.get("matched", 0), indexed=False, type="blob", listable=False).to_string())
-            )
-            columns.append(ColumnDefinition(
-                column_name="_as_format", type_name="image_format_enum",
-                options=ColumnOptions(
-                    count=data.get("matched", 0), indexed=False, type="string", listable=False).to_string())
-            )
-            columns.append(ColumnDefinition(
-                column_name="_operations", type_name="jsonb",
-                options=ColumnOptions(
-                    count=data.get("matched", 0), indexed=False, type="json", listable=False).to_string())
-            )
-            options.blob_column = "_image"
-            options.operation_types = ["threshold",
-                                       "resize", "crop", "rotate", "flip"]
+        if entity in OBJECT_COLUMNS_HANDLERS:
+            columns.extend(OBJECT_COLUMNS_HANDLERS[entity]())
 
     except Exception as e:
         logger.exception(
@@ -112,17 +164,20 @@ def system_entity_table() -> TableDefinition:
     """
     table_name = "Entity"
 
+    count = sum(
+        data.get("matched", 0) for class_, data in get_schema().get("entities", {}).get("classes", {}).items() if class_[0] != "_"
+    )
+
     options = TableOptions(
-        type="entity",
+        table_name=f'system."{table_name}"',
+        count=count,
         command=f"FindEntity",
         result_field="entities",
     )
 
-    columns = [
-        ColumnDefinition(
-            column_name="_uniqueid", type_name="text",
-            options=ColumnOptions(count=0, indexed=True, unique=True, type="string").to_string()),
-    ]
+    columns = []
+
+    columns.append(uniqueid_column(count))
 
     columns.extend(get_consistent_properties("entities"))
 
@@ -146,25 +201,39 @@ def system_connection_table() -> TableDefinition:
     """
     table_name = "Connection"
 
+    count = sum(
+        data.get("matched", 0) for class_, data in get_schema().get("connections", {}).get("classes", {}).items()
+    )
+
     options = TableOptions(
-        type="connection",
+        table_name=f'system."{table_name}"',
+        count=count,
         command="FindConnection",
         result_field="connections",
    )
 
-    columns = [
-        ColumnDefinition(
-            column_name="_uniqueid", type_name="text",
-            options=ColumnOptions(count=0, indexed=True, unique=True, type="string").to_string()),
-    ]
+    columns = []
+
+    columns.append(uniqueid_column(count))
 
     columns.extend(get_consistent_properties("connections"))
 
     # Add the _src, and _dst columns
     columns.append(ColumnDefinition(
-        column_name="_src", type_name="text", options=ColumnOptions(indexed=True, type="string").to_string()))
+        column_name="_src",
+        type_name="text",
+        options=ColumnOptions(
+            indexed=True,
+            type="string",
+        ).to_string()))
+
     columns.append(ColumnDefinition(
-        column_name="_dst", type_name="text", options=ColumnOptions(indexed=True, type="string").to_string()))
+        column_name="_dst",
+        type_name="text",
+        options=ColumnOptions(
+            indexed=True,
+            type="string",
+        ).to_string()))
 
     logger.debug(
         f"Creating system connection table as {table_name} with columns: {columns} and options: {options}")
@@ -174,13 +243,14 @@ def system_connection_table() -> TableDefinition:
         columns=columns,
         options=options.to_string()
     )
+    logger.debug(f"System connection table created: {result}")
 
     return result
 
 
 def get_consistent_properties(type_: Literal["entities", "connections"]) -> List[ColumnDefinition]:
     """
-    Get column definitions for properties that are consistent across all classes of a given type.
+    Get column definitions for properties that are consistently typed across all classes of a given type.
     """
     property_types = defaultdict(set)
     schema = get_schema()
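An illustrative aside (not part of the diff): the _operations column defined above accepts a JSONB array that operations_passthrough validates and then copies verbatim into the command body. A sketch of that contract, assuming the fdw package from this patch is installed and that width/height are acceptable resize parameters:

    from fdw.system import operations_passthrough   # assumes the fdw package is importable

    ops = [{"type": "resize", "width": 224, "height": 224}]   # parameter names are illustrative
    command_body = {}
    operations_passthrough(
        types={"threshold", "resize", "crop", "rotate", "flip"},
        value=ops,
        command_body=command_body,
    )
    assert command_body["operations"] == ops   # passed through unchanged

    # An entry whose "type" is not in the allowed set raises ValueError with an SQL-oriented message.
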
diff --git a/apps/sql-server/fdw/fdw/table.py b/apps/sql-server/fdw/fdw/table.py
new file mode 100644
index 00000000..67fb6911
--- /dev/null
+++ b/apps/sql-server/fdw/fdw/table.py
@@ -0,0 +1,69 @@
+from .common import Curry
+import logging
+from pydantic import BaseModel
+from typing import List, Dict, Any, Optional
+
+
+logger = logging.getLogger(__name__)
+
+
+class TableOptions(BaseModel):
+    """
+    Options passed to the foreign table from `import_schema`.
+    """
+    # name of the table in PostgreSQL
+    table_name: str
+    # number of objects, probably from "matched" field in GetSchema response
+    count: int = 0
+    # command to execute, e.g. "FindEntity", "FindConnection", etc.
+    command: str = "FindEntity"
+    # field to look for in the response, e.g. "entities", "connections"
+    result_field: str = "entities"
+
+    # This hook is used to modify the command body before executing it.
+    # It is passed the command body as `command_body`.
+    # It should modify the command body in place.
+    modify_command_body: Optional[Curry] = None
+
+    def model_post_init(self, context: Any):
+        """
+        Validate the options after model initialization.
+        """
+        # Check that modify_command_body has a valid function signature
+        if self.modify_command_body:
+            self.modify_command_body.validate_signature({"command_body"})
+
+    @classmethod
+    def from_string(cls, options_str: Dict[str, str]) -> "TableOptions":
+        """
+        Create a TableOptions instance from a string dictionary.
+        This is used to decode options from the foreign table definition.
+        Postgres restricts options to be a string-valued dictionary.
+        """
+        return cls.model_validate_json(options_str["table_options"])
+
+    def to_string(self) -> Dict[str, str]:
+        """
+        Convert TableOptions to a string dictionary.
+        This is used to encode options for the foreign table definition.
+        """
+        return {"table_options": self.model_dump_json()}
+
+    # Reject any extra fields that are not defined in the model.
+    model_config = {
+        "extra": "forbid"
+    }
+
+
+# Utility functions for Curry hooks
+
+
+def literal(parameters: Dict[str, Any],
+            command_body: Dict[str, Any]) -> None:
+    """
+    A TableOptions modify_command_body hook.
+
+    Adds the given parameters to the command body.
+    This is used to modify the command body before executing it.
+    """
+    command_body.update(parameters)
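An illustrative aside (not part of the diff): TableOptions is flattened into a single "table_options" JSON string because, as the docstring above notes, Postgres only allows string-valued options on a foreign table. A standalone sketch of that encode/decode pattern using a trimmed-down model (the real class also carries the Curry hook, omitted here):

    from typing import Dict
    from pydantic import BaseModel

    class MiniTableOptions(BaseModel):
        table_name: str
        count: int = 0
        command: str = "FindEntity"
        result_field: str = "entities"
        model_config = {"extra": "forbid"}

        @classmethod
        def from_string(cls, options_str: Dict[str, str]) -> "MiniTableOptions":
            # Decode the single JSON-string option back into a model.
            return cls.model_validate_json(options_str["table_options"])

        def to_string(self) -> Dict[str, str]:
            # Encode the whole model as one string-valued option.
            return {"table_options": self.model_dump_json()}

    opts = MiniTableOptions(table_name='entity."CrawlDocument"', count=42)   # values illustrative
    encoded = opts.to_string()
    assert MiniTableOptions.from_string(encoded) == opts
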
diff --git a/base/docker/scripts/embeddings/embeddings.py b/base/docker/scripts/embeddings/embeddings.py
index 95cfad66..62efb53c 100644
--- a/base/docker/scripts/embeddings/embeddings.py
+++ b/base/docker/scripts/embeddings/embeddings.py
@@ -156,6 +156,23 @@ def get_properties(self) -> dict:
             "embeddings_fingerprint": self.fingerprint_hash(),
         }
 
+    @classmethod
+    def check_properties(cls, properties: dict) -> bool:
+        """
+        Check if the properties are valid for this embedder.
+        This attempts to verify that the properties contain the required keys for an embedder.
+
+        Args:
+            properties (dict): The properties to check, as returned by `FindDescriptorSet`.
+
+        Returns:
+            bool: True if the properties are valid for this embedder, False otherwise.
+        """
+        required_keys = ["embeddings_provider",
+                         "embeddings_model", "embeddings_pretrained"]
+        # TODO: Consider adding more checks, e.g., for the provider and model name.
+        return all(key in properties for key in required_keys)
+
     @classmethod
     def from_properties(cls, properties: dict,